Compare commits

..

74 Commits

Author SHA1 Message Date
Bai
a474d9be3e perf: serialize_nodes/assets as tree-node, if not nodes or assets, return 2025-12-24 16:52:24 +08:00
Bai
3a4e93af2f refactor: finished AssetTree API. support gloabl and real org; support asset-tree of AssetPage and PermPage; support search tree re-initial 2025-12-24 15:45:39 +08:00
Bai
9c2ddbba7e refactor: while search tree, re-initial ZTree use API data. 2025-12-23 18:51:26 +08:00
Bai
39129cecbe refactor: finished NodeChildrenAsTreeApi. But, need TODO support GLOBAL org logic. 2025-12-23 18:46:18 +08:00
Bai
88819bbf26 perf: Modify AssetViewSet filter by node 2025-12-22 17:13:05 +08:00
Bai
a88e35156a perf: AssetTree if with_assets pre fetch asset attrs 2025-12-22 16:51:08 +08:00
Bai
22a27946a7 perf: AssetTree support with_assets and full_tree kwargs 2025-12-22 16:32:48 +08:00
Bai
4983465a23 perf: UserPermTree support with_assets params 2025-12-22 13:28:32 +08:00
Bai
4d9fc9dfd6 perf: UserPermUtil supoort get_node_assets and get_node_all_assets 2025-12-22 12:32:27 +08:00
Bai
c7cb83fa1d perf: split UserPermUtil from UserPermTree 2025-12-22 11:23:13 +08:00
Bai
ee92c72b50 perf: add UserPermTree, finished 2025-12-19 19:53:08 +08:00
Bai
6a05fbe0fe perf: add AssetSearchTree, move remove_zero_assets_node from asset-tree to asset-search-tree 2025-12-19 15:19:54 +08:00
Bai
0284be169a perf: add AssetSearchTree, modify Node-Model get_all_assets function use node not nodes 2025-12-19 14:44:17 +08:00
Bai
a4e9d4f815 perf: add AssetSearchTree, supported: category search; modify fake generate asset set node_id 2025-12-19 14:25:16 +08:00
Bai
bbe549696a perf: add AssetSearchTree, not yet supported: category search 2025-12-19 10:54:07 +08:00
Bai
56f720271a refactor: format tree.print 2025-12-18 18:30:36 +08:00
Bai
9755076f7f refactor: add tree.py and asset_tree.py, finished build AssetTree. 2025-12-18 15:37:18 +08:00
Bai
8d7abef191 perf: add migrations - migrate asset node_id field 2025-12-16 18:50:15 +08:00
Bai
aaa40722c4 perf: add util - cleanup and kepp one node for Multi-Parent-Nodes Assets and generate report 2025-12-16 16:29:24 +08:00
Bai
ca39344937 perf: add util - cleanup and kepp one node for Multi-Parent-Nodes Assets and generate report 2025-12-16 16:28:37 +08:00
Bai
4b9a8227c9 perf: add util - find Multi-Parent Assets and generate report 2025-12-16 15:32:41 +08:00
feng
f362163af1 perf: remove gpt model 2025-12-16 13:19:45 +08:00
fit2bot
5f1ba56e56 Merge pull request #16094 from jumpserver/pr@dev@chat_model
perf: Add open ui chat model
2025-12-10 10:43:14 +08:00
Chenyang Shen
2b1fdb937b Merge pull request #16404 from jumpserver/pr@dev@feat_reset_key_store
feat: reset piico device after open device
2025-12-09 15:16:41 +08:00
Aaron3S
1e754546f1 feat: reset piico device after open device 2025-12-09 14:47:37 +08:00
Bai
2ec71feafc perf: rbac oauth2_provider perms i18n 2025-12-09 10:17:34 +08:00
Bai
02e8905330 perf: redirect/confirm page and i18n 2025-12-08 18:43:04 +08:00
Bai
8d68f5589b perf: redirect/confirm page and i18n 2025-12-08 18:43:04 +08:00
Bai
4df13fc384 perf: redirect/confirm page and i18n 2025-12-08 18:40:12 +08:00
Bai
78c1162028 perf: when DEBUG_DEV=True, allow OAUTH2_PROVIDER redirect_url localhost listen 2025-12-08 16:42:07 +08:00
Bai
14c2512b45 fix: accesskey authentication user is None error 2025-12-08 15:06:47 +08:00
Bai
d6d7072da5 perf: request.GET.copy() to dict(), because copy() returned values is list [] 2025-12-08 12:50:49 +08:00
fit2bot
993bc36c5e perf: handling the next parameter propagation issue in third-party authentication flows (#16395)
* perf: remove call client old- method via ?next=client

* feat: add 2 decorators for login-get and login-callback-get to set next_page and get next_page from session

* perf: code style

* perf: handling the next parameter propagation issue in third-party authentication flows

* perf: request.GET.dict() to copy()

* perf: style import

---------

Co-authored-by: Bai <baijiangjie@gmail.com>
2025-12-08 12:34:32 +08:00
fit2bot
ecff2ea07e perf: move oauth2_provider api auth_backend to the end, and while accesstoken_backend not user do not raise execption, go on next bakcned auth (#16393)
* perf: move oauth2_provider api auth_backend to the end, and while accesstoken_backend not user do not raise execption, go on next bakcned auth

* perf: re-sorted DEFAULT_AUTHENTICATION_CLASSES

---------

Co-authored-by: Bai <baijiangjie@gmail.com>
2025-12-08 09:57:17 +08:00
fit2bot
ba70edf221 perf: when oauth2 application delete expired well-known page cache via post_delete signal (#16392)
Co-authored-by: Bai <baijiangjie@gmail.com>
2025-12-08 09:54:18 +08:00
Bai
50050dff57 fix: cas only allow exist user login 2025-12-04 18:37:54 +08:00
jiangweidong
944226866c perf: Add a diff field to operate-log export 2025-12-04 18:01:01 +08:00
fit2bot
fe13221d88 fix: Improve server URI validation and connection testing in LDAP module (#16377)
Co-authored-by: wangruidong <940853815@qq.com>
2025-12-04 17:59:01 +08:00
fit2bot
ba17863892 perf: Remove unused CAS user exception handling and simplify login view error response (#16380)
* perf: Remove unused CAS user exception handling and simplify login view error response

* perf: position code

---------

Co-authored-by: wangruidong <940853815@qq.com>
Co-authored-by: Bai <baijiangjie@gmail.com>
2025-12-04 17:49:58 +08:00
fit2bot
065bfeda52 fix: only exists user login maybe invalid (#16379)
* fix: only exists user login maybe invalid

* fix: only exists user login maybe invalid

* fix: only exists user login maybe invalid

---------

Co-authored-by: Bai <baijiangjie@gmail.com>
2025-12-04 16:18:47 +08:00
wangruidong
04af26500a fix: Allow login with username or email for existing users 2025-12-04 10:04:32 +08:00
fit2bot
e0388364c3 fix: use third part authentication service rediect to client failed (#16370)
* perf: .well-known cached 1h and support saml2 redirect_to client

* fix: support wecom redirect_to client (reslove wecom waf 501 error)

* fix: support oauth2 auth rediect to client

* fix: safe next url

---------

Co-authored-by: Bai <baijiangjie@gmail.com>
2025-12-03 19:07:00 +08:00
Bai
3c96480b0c perf: add manage.py command: init_oauth2_provider, resolve init jumpserver client failed issue 2025-12-03 14:37:20 +08:00
Bai
95331a0c4b perf: redirect to client show tips 2025-12-02 18:39:48 +08:00
Bai
b8ecb703cf perf: url revoke_token/ to revoke/ 2025-12-02 18:21:13 +08:00
Bai
1a3f5e3f9a perf: default access token/refresh token expired at 1h/7day 2025-12-02 15:34:55 +08:00
Bai
854396e8d5 perf: access-token api 2025-12-02 15:25:55 +08:00
Bai
ab08603e66 perf: organize oauth2_provider urls, add .well-known API 2025-12-02 14:55:09 +08:00
Bai
427fd3f72c perf: organize oauth2_provider urls, add .well-known API 2025-12-02 14:55:09 +08:00
Bai
0aba9ba120 perf: hide the unused URLs in OAuth2 provider 2025-12-02 14:55:09 +08:00
Bai
045ca8807a feat: modify client redirect url 2025-12-01 19:04:19 +08:00
Bai
19a68d8930 feat: add api access token 2025-12-01 17:55:08 +08:00
Bai
75ed02a2d2 feat: add oauth2 provider accesstokens api 2025-12-01 17:55:08 +08:00
fit2bot
f420dac49c feat: Host cloud sync supports state cloud - i18n (#16304)
Co-authored-by: jiangweidong <1053570670@qq.com>
Co-authored-by: Jiangjie Bai <jiangjie.bai@fit2cloud.com>
2025-12-01 10:56:14 +08:00
Bai
1ee68134f2 fix: rename utils methond 2025-12-01 10:41:14 +08:00
Bai
937265db5d perf: add period task clear oauth2 provider expired tokens 2025-12-01 10:41:14 +08:00
Bai
c611d5e88b perf: add utils delete oauth2 provider application 2025-12-01 10:41:14 +08:00
Bai
883b6b6383 perf: skip_authorization for redirect to jms client 2025-12-01 10:41:14 +08:00
Bai
ac4c72064f perf: register jumpserver client logic 2025-12-01 10:41:14 +08:00
Bai
dbf8360e27 feat: add OAUTH2_PROVIDER_ACCESS_TOKEN_EXPIRE_SECONDS 2025-12-01 10:41:14 +08:00
github-actions[bot]
150d7a09bc perf: Update Dockerfile with new base image tag 2025-11-28 16:28:23 +08:00
Bai
a7ed20e059 perf: support as oauth2 provider 2025-11-28 16:28:23 +08:00
github-actions[bot]
1b7b8e6f2e perf: Update Dockerfile with new base image tag 2025-11-28 16:28:23 +08:00
Bai
cd22fbce19 perf: support as oauth2 provider 2025-11-28 16:28:23 +08:00
老广
c191d86f43 Refactor GitHub Actions workflow for event handling 2025-11-27 14:27:27 +08:00
wangruidong
7911137ffb fix: Truncate asset URL to 128 characters to prevent exceeding length limit 2025-11-27 14:17:19 +08:00
wangruidong
1053933cae fix: Add migration to refresh PostgreSQL collation version 2025-11-27 14:16:44 +08:00
wangruidong
96fdc025cd fix: Search for risk_level, search result is empty 2025-11-26 18:07:20 +08:00
wangruidong
fde19764e0 fix: Processing redirection url unquote 2025-11-25 14:00:31 +08:00
wangruidong
978fbc70e6 perf: Improve city retrieval fallback to handle missing values 2025-11-25 13:59:48 +08:00
Ewall555
636ffd786d feat: add namespace setting to k8s protocol configuration 2025-11-25 11:08:23 +08:00
feng
3b756aa26f perf: Component i18n lang lower 2025-11-25 10:56:37 +08:00
Bai
817c0099d1 perf: client pkg rename 2025-11-21 18:45:49 +08:00
Bai
a0d7871130 perf: client pkg rename 2025-11-21 18:45:49 +08:00
116 changed files with 11454 additions and 6322 deletions

View File

@@ -1,10 +1,33 @@
on: [push, pull_request, release]
on:
push:
pull_request:
types: [opened, synchronize, closed]
release:
types: [created]
name: JumpServer repos generic handler
jobs:
generic_handler:
name: Run generic handler
handle_pull_request:
if: github.event_name == 'pull_request'
runs-on: ubuntu-latest
steps:
- uses: jumpserver/action-generic-handler@master
env:
GITHUB_TOKEN: ${{ secrets.PRIVATE_TOKEN }}
I18N_TOKEN: ${{ secrets.I18N_TOKEN }}
handle_push:
if: github.event_name == 'push'
runs-on: ubuntu-latest
steps:
- uses: jumpserver/action-generic-handler@master
env:
GITHUB_TOKEN: ${{ secrets.PRIVATE_TOKEN }}
I18N_TOKEN: ${{ secrets.I18N_TOKEN }}
handle_release:
if: github.event_name == 'release'
runs-on: ubuntu-latest
steps:
- uses: jumpserver/action-generic-handler@master

View File

@@ -1,4 +1,4 @@
FROM jumpserver/core-base:20251113_092612 AS stage-build
FROM jumpserver/core-base:20251128_025056 AS stage-build
ARG VERSION

View File

@@ -5,6 +5,7 @@ from rest_framework.request import Request
from assets.models import Node, Platform, Protocol, MyAsset
from assets.utils import get_node_from_request, is_query_node_all_assets
from common.utils import lazyproperty, timeit
from assets.tree.asset_tree import AssetTreeNode
class SerializeToTreeNodeMixin:
@@ -19,22 +20,22 @@ class SerializeToTreeNodeMixin:
return False
@timeit
def serialize_nodes(self, nodes: List[Node], with_asset_amount=False):
if with_asset_amount:
def _name(node: Node):
return '{} ({})'.format(node.value, node.assets_amount)
else:
def _name(node: Node):
return node.value
def serialize_nodes(self, nodes: List[AssetTreeNode], with_asset_amount=False, expand_level=1, with_assets=False):
if not nodes:
return []
def _open(node):
if not self.is_sync:
# 异步加载资产树时,默认展开节点
return True
if not node.parent_key:
return True
def _name(node: AssetTreeNode):
v = node.value
if not with_asset_amount:
return v
v = f'{v} ({node.assets_amount_total})'
return v
def is_parent(node: AssetTreeNode):
if with_assets:
return node.assets_amount > 0 or not node.is_leaf
else:
return False
return not node.is_leaf
data = [
{
@@ -42,15 +43,17 @@ class SerializeToTreeNodeMixin:
'name': _name(node),
'title': _name(node),
'pId': node.parent_key,
'isParent': True,
'open': _open(node),
'isParent': is_parent(node),
'open': node.level <= expand_level,
'meta': {
'type': 'node',
'data': {
"id": node.id,
"key": node.key,
"value": node.value,
"assets_amount": node.assets_amount,
"assets_amount_total": node.assets_amount_total,
},
'type': 'node'
}
}
for node in nodes
@@ -72,6 +75,9 @@ class SerializeToTreeNodeMixin:
@timeit
def serialize_assets(self, assets, node_key=None, get_pid=None):
if not assets:
return []
if not get_pid and not node_key:
get_pid = lambda asset, platform: getattr(asset, 'parent_key', '')

View File

@@ -1,6 +1,6 @@
# ~*~ coding: utf-8 ~*~
from django.db.models import Q
from django.db.models import Q, Count
from django.utils.translation import gettext_lazy as _
from rest_framework.generics import get_object_or_404
from rest_framework.response import Response
@@ -11,12 +11,16 @@ from common.tree import TreeNodeSerializer
from common.utils import get_logger
from orgs.mixins import generics
from orgs.utils import current_org
from orgs.models import Organization
from .mixin import SerializeToTreeNodeMixin
from .. import serializers
from ..const import AllTypes
from ..models import Node, Platform, Asset
from assets.tree.asset_tree import AssetTree
logger = get_logger(__file__)
__all__ = [
'NodeChildrenApi',
'NodeChildrenAsTreeApi',
@@ -25,14 +29,13 @@ __all__ = [
class NodeChildrenApi(generics.ListCreateAPIView):
"""
节点的增删改查
"""
''' 节点的增删改查 '''
serializer_class = serializers.NodeSerializer
search_fields = ('value',)
instance = None
is_initial = False
perm_model = Node
def initial(self, request, *args, **kwargs):
super().initial(request, *args, **kwargs)
@@ -65,42 +68,16 @@ class NodeChildrenApi(generics.ListCreateAPIView):
else:
node = Node.org_root()
return node
if pk:
node = get_object_or_404(Node, pk=pk)
else:
node = get_object_or_404(Node, key=key)
return node
def get_org_root_queryset(self, query_all):
if query_all:
return Node.objects.all()
else:
return Node.org_root_nodes()
def get_queryset(self):
query_all = self.request.query_params.get("all", "0") == "all"
if self.is_initial and current_org.is_root():
return self.get_org_root_queryset(query_all)
if self.is_initial:
with_self = True
else:
with_self = False
if not self.instance:
return Node.objects.none()
if query_all:
queryset = self.instance.get_all_children(with_self=with_self)
else:
queryset = self.instance.get_children(with_self=with_self)
return queryset
class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
"""
节点子节点作为树返回,
''' 节点子节点作为树返回,
[
{
"id": "",
@@ -109,51 +86,96 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
"meta": ""
}
]
'''
"""
model = Node
def filter_queryset(self, queryset):
""" queryset is Node queryset """
if not self.request.GET.get('search'):
return queryset
queryset = super().filter_queryset(queryset)
queryset = self.model.get_ancestor_queryset(queryset)
return queryset
def get_queryset_for_assets(self):
query_all = self.request.query_params.get("all", "0") == "all"
include_assets = self.request.query_params.get('assets', '0') == '1'
if not self.instance or not include_assets:
return Asset.objects.none()
if not self.request.GET.get('search') and self.instance.is_org_root():
return Asset.objects.none()
if query_all:
assets = self.instance.get_all_assets()
else:
assets = self.instance.get_assets()
return assets.only(
"id", "name", "address", "platform_id",
"org_id", "is_active", 'comment'
).prefetch_related('platform')
def filter_queryset_for_assets(self, assets):
search = self.request.query_params.get('search')
if search:
q = Q(name__icontains=search) | Q(address__icontains=search)
assets = assets.filter(q)
return assets
def list(self, request, *args, **kwargs):
nodes = self.filter_queryset(self.get_queryset()).order_by('value')
search = request.query_params.get('search')
with_assets = request.query_params.get('assets', '0') == '1'
with_asset_amount = request.query_params.get('asset_amount', '1') == '1'
nodes = self.serialize_nodes(nodes, with_asset_amount=with_asset_amount)
assets = self.filter_queryset_for_assets(self.get_queryset_for_assets())
node_key = self.instance.key if self.instance else None
assets = self.serialize_assets(assets, node_key=node_key)
with_asset_amount = True
nodes, assets, expand_level = self.get_nodes_assets(search, with_assets)
nodes = self.serialize_nodes(nodes, with_asset_amount=with_asset_amount, expand_level=expand_level)
assets = self.serialize_assets(assets)
data = [*nodes, *assets]
return Response(data=data)
def get_nodes_assets(self, search, with_assets):
#
# 资产管理-节点树
#
# 全局组织: 初始化节点树, 返回所有节点, 不包含资产, 不展开节点
# 实体组织: 初始化节点树, 返回所有节点, 不包含资产, 展开一级节点
# 前端搜索
if not with_assets:
if current_org.is_root():
orgs = Organization.objects.all()
expand_level = 0
else:
orgs = [current_org]
expand_level = 1
nodes = []
assets = []
for org in orgs:
tree = AssetTree(org=org)
org_nodes = tree.get_nodes()
nodes.extend(org_nodes)
return nodes, assets, expand_level
#
# 权限管理、账号发现、风险检测 - 资产节点树
#
# 全局组织: 搜索资产, 生成资产节点树, 过滤每个组织前 1000 个资产, 展开所有节点
# 实体组织: 搜索资产, 生成资产节点树, 过滤前 1000 个资产, 展开所有节点
if search:
if current_org.is_root():
orgs = list(Organization.objects.all())
else:
orgs = [current_org]
nodes = []
assets = []
assets_q_object = Q(name__icontains=search) | Q(address__icontains=search)
with_assets_limit = 1000 / len(orgs)
for org in orgs:
tree = AssetTree(
assets_q_object=assets_q_object, org=org,
with_assets=True, with_assets_limit=with_assets_limit, full_tree=False
)
nodes.extend(tree.get_nodes())
assets.extend(tree.get_assets())
expand_level = 10000 # search 时展开所有节点
return nodes, assets, expand_level
# 全局组织: 展开某个节点及其资产
# 实体组织: 展开某个节点及其资产
# 实体组织: 初始化资产节点树, 自动展开根节点及其资产, 所以节点要包含自己 (特殊情况)
if self.instance:
nodes = []
tree = AssetTree(with_assets_node_id=self.instance.id, org=self.instance.org)
nodes_with_self = False
if not current_org.is_root() and self.instance.is_org_root():
nodes_with_self = True
nodes = tree.get_node_children(key=self.instance.key, with_self=nodes_with_self)
assets = tree.get_assets()
expand_level = 1 # 默认只展开第一级
return nodes, assets, expand_level
# 全局组织: 初始化资产节点树, 仅返回各组织根节点, 不展开
orgs = Organization.objects.all()
nodes = []
assets = []
for org in orgs:
tree = AssetTree(org=org, with_assets=False)
if not tree.root:
continue
nodes.append(tree.root)
expand_level = 0 # 默认不展开节点
return nodes, assets, expand_level
class CategoryTreeApi(SerializeToTreeNodeMixin, generics.ListAPIView):
serializer_class = TreeNodeSerializer

View File

@@ -268,6 +268,14 @@ class Protocol(ChoicesMixin, models.TextChoices):
'port_from_addr': True,
'required': True,
'secret_types': ['token'],
'setting': {
'namespace': {
'type': 'str',
'required': False,
'default': '',
'label': _('Namespace')
}
}
},
cls.http: {
'port': 80,

View File

@@ -63,11 +63,11 @@ class NodeFilterBackend(filters.BaseFilterBackend):
query_all = is_query_node_all_assets(request)
if query_all:
return queryset.filter(
Q(nodes__key__startswith=f'{node.key}:') |
Q(nodes__key=node.key)
Q(node__key__startswith=f'{node.key}:') |
Q(node__key=node.key)
).distinct()
else:
return queryset.filter(nodes__key=node.key).distinct()
return queryset.filter(node__key=node.key).distinct()
class IpInFilterBackend(filters.BaseFilterBackend):

View File

@@ -0,0 +1,126 @@
# Generated by Django 4.1.13 on 2025-12-16 09:14
from django.db import migrations, models, transaction
import django.db.models.deletion
def log(msg=''):
print(f' -> {msg}')
def ensure_asset_single_node(apps, schema_editor):
print('')
log('Checking that all assets are linked to only one node...')
Asset = apps.get_model('assets', 'Asset')
Through = Asset.nodes.through
assets_count_multi_nodes = Through.objects.values('asset_id').annotate(
node_count=models.Count('node_id')
).filter(node_count__gt=1).count()
if assets_count_multi_nodes > 0:
raise Exception(
f'There are {assets_count_multi_nodes} assets associated with more than one node. '
'Please ensure each asset is linked to only one node before applying this migration.'
)
else:
log('All assets are linked to only one node. Proceeding with the migration.')
def ensure_asset_has_node(apps, schema_editor):
log('Checking that all assets are linked to at least one node...')
Asset = apps.get_model('assets', 'Asset')
Through = Asset.nodes.through
asset_count = Asset.objects.count()
through_asset_count = Through.objects.values('asset_id').count()
assets_count_without_node = asset_count - through_asset_count
if assets_count_without_node > 0:
raise Exception(
f'Some assets ({assets_count_without_node}) are not associated with any node. '
'Please ensure all assets are linked to a node before applying this migration.'
)
else:
log('All assets are linked to a node. Proceeding with the migration.')
def migrate_asset_node_id_field(apps, schema_editor):
log('Migrating node_id field for all assets...')
Asset = apps.get_model('assets', 'Asset')
Through = Asset.nodes.through
assets = Asset.objects.filter(node_id__isnull=True)
log (f'Found {assets.count()} assets to migrate.')
asset_node_mapper = {
str(asset_id): str(node_id)
for asset_id, node_id in Through.objects.values_list('asset_id', 'node_id')
}
# 测试
asset_node_mapper.pop(None, None) # Remove any entries with None keys
for asset in assets:
node_id = asset_node_mapper.get(str(asset.id))
if not node_id:
raise Exception(
f'Asset (ID: {asset.id}) is not associated with any node. '
'Cannot migrate node_id field.'
)
asset.node_id = node_id
with transaction.atomic():
total = len(assets)
batch_size = 5000
for i in range(0, total, batch_size):
batch = assets[i:i+batch_size]
start = i + 1
end = min(i + batch_size, total)
for asset in batch:
asset.save(update_fields=['node_id'])
log(f"Migrated {start}-{end}/{total} assets")
count = Asset.objects.filter(node_id__isnull=True).count()
if count > 0:
log('Warning: Some assets still have null node_id after migration.')
raise Exception('Migration failed: Some assets have null node_id.')
count = Asset.objects.filter(node_id__isnull=False).count()
log(f'Successfully migrated node_id for {count} assets.')
class Migration(migrations.Migration):
dependencies = [
('assets', '0019_alter_asset_connectivity'),
]
operations = [
migrations.RunPython(
ensure_asset_single_node,
reverse_code=migrations.RunPython.noop
),
migrations.RunPython(
ensure_asset_has_node,
reverse_code=migrations.RunPython.noop
),
migrations.AddField(
model_name='asset',
name='node',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='direct_assets', to='assets.node', verbose_name='Node'),
),
migrations.RunPython(
migrate_asset_node_id_field,
reverse_code=migrations.RunPython.noop
),
migrations.AlterField(
model_name='asset',
name='node',
field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='direct_assets', to='assets.node', verbose_name='Node'),
),
]

View File

@@ -172,6 +172,11 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
"assets.Zone", null=True, blank=True, related_name='assets',
verbose_name=_("Zone"), on_delete=models.SET_NULL
)
node = models.ForeignKey(
'assets.Node', null=False, blank=False, on_delete=models.PROTECT,
related_name='direct_assets', verbose_name=_("Node")
)
# TODO: 删除完代码中所有使用的地方后,再删除 nodes 字段,并将 node 字段的 related_name 改为 'assets'
nodes = models.ManyToManyField(
'assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")
)

View File

@@ -394,7 +394,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
def get_all_assets(self):
from .asset import Asset
q = Q(nodes__key__startswith=f'{self.key}:') | Q(nodes__key=self.key)
q = Q(node__key__startswith=f'{self.key}:') | Q(node__key=self.key)
return Asset.objects.filter(q).distinct()
def get_assets_amount(self):
@@ -416,8 +416,8 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
def get_assets(self):
from .asset import Asset
assets = Asset.objects.filter(nodes=self)
return assets.distinct()
assets = Asset.objects.filter(node=self)
return assets
def get_valid_assets(self):
return self.get_assets().valid()
@@ -531,6 +531,15 @@ class SomeNodesMixin:
root_nodes = cls.objects.filter(parent_key='', key__regex=r'^[0-9]+$') \
.exclude(key__startswith='-').order_by('key')
return root_nodes
@classmethod
def get_or_create_org_root(cls, org):
org_root = cls.org_root_nodes().filter(org_id=org.id).first()
if org_root:
return org_root
with tmp_to_org(org):
org_root = cls.create_org_root_node()
return org_root
class Node(JMSOrgBaseModel, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):

View File

@@ -186,6 +186,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
super().__init__(*args, **kwargs)
self._init_field_choices()
self._extract_accounts()
self._set_platform()
def _extract_accounts(self):
if not getattr(self, 'initial_data', None):
@@ -217,6 +218,21 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
protocols_data = [{'name': p.name, 'port': p.port} for p in protocols]
self.initial_data['protocols'] = protocols_data
def _set_platform(self):
if not hasattr(self, 'initial_data'):
return
platform_id = self.initial_data.get('platform')
if not platform_id:
return
if isinstance(platform_id, int) or str(platform_id).isdigit() or not isinstance(platform_id, str):
return
platform = Platform.objects.filter(name=platform_id).first()
if not platform:
return
self.initial_data['platform'] = platform.id
def _init_field_choices(self):
request = self.context.get('request')
if not request:
@@ -265,8 +281,10 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
if not platform_id and self.instance:
platform = self.instance.platform
else:
elif isinstance(platform_id, int):
platform = Platform.objects.filter(id=platform_id).first()
else:
platform = Platform.objects.filter(name=platform_id).first()
if not platform:
raise serializers.ValidationError({'platform': _("Platform not exist")})
@@ -297,6 +315,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
def is_valid(self, raise_exception=False):
self._set_protocols_default()
self._set_platform()
return super().is_valid(raise_exception=raise_exception)
def validate_protocols(self, protocols_data):

View File

@@ -0,0 +1 @@
from .asset_tree import *

View File

@@ -0,0 +1,246 @@
from collections import defaultdict
from django.db.models import Count, Q
from orgs.utils import current_org
from orgs.models import Organization
from assets.models import Asset, Node, Platform
from assets.const.category import Category
from common.utils import get_logger, timeit, lazyproperty
from .tree import TreeNode, Tree
logger = get_logger(__name__)
__all__ = ['AssetTree', 'AssetTreeNode']
class AssetTreeNodeAsset:
def __init__(self, id, node_id, parent_key, name, address,
platform_id, is_active, comment, org_id):
self.id = id
self.node_id = node_id
self.parent_key = parent_key
self.name = name
self.address = address
self.platform_id = platform_id
self.is_active = is_active
self.comment = comment
self.org_id = org_id
@lazyproperty
def org(self):
return Organization.get_instance(self.org_id)
@property
def org_name(self) -> str:
return self.org.name
class AssetTreeNode(TreeNode):
def __init__(self, _id, key, value, assets_amount=0, assets=None):
super().__init__(_id, key, value)
self.assets_amount = assets_amount
self.assets_amount_total = 0
self.assets: list[AssetTreeNodeAsset] = []
self.init_assets(assets)
def init_assets(self, assets):
if not assets:
return
for asset in assets:
asset['parent_key'] = self.key
self.assets.append(AssetTreeNodeAsset(**asset))
def get_assets(self):
return self.assets
def as_dict(self, simple=True):
data = super().as_dict(simple=simple)
data.update({
'assets_amount_total': self.assets_amount_total,
'assets_amount': self.assets_amount,
'assets': len(self.assets),
})
return data
class AssetTree(Tree):
TreeNode = AssetTreeNode
def __init__(self, assets_q_object: Q = None, category=None, org=None,
with_assets=False, with_assets_node_id=None, with_assets_limit=1000,
full_tree=True):
'''
:param assets_q_object: 只生成这些资产所在的节点树
:param category: 只生成该类别资产所在的节点树
:param org: 只生成该组织下的资产节点树
:param with_assets_node_id: 仅指定节点下包含资产
:param with_assets: 所有节点都包含资产
:param with_assets_limit: 包含资产时, 所有资产的最大数量
:param full_tree: 完整树包含所有节点否则只包含节点的资产总数不为0的节点
'''
super().__init__()
## 通过资产构建节点树, 支持 Q, category, org 等过滤条件 ##
self._assets_q_object: Q = assets_q_object or Q()
self._category = self._check_category(category)
self._category_platform_ids = set()
self._org: Organization = org or current_org
# org 下全量节点属性映射, 构建资产树时根据完整的节点进行构建
self._nodes_attr_mapper = defaultdict(dict)
# 节点直接资产数量映射, 用于计算节点下总资产数量
self._nodes_assets_amount_mapper = defaultdict(int)
# 节点下是否包含资产
self._with_assets = with_assets # 所有节点都包含资产
self._with_assets_node_id = with_assets_node_id # 仅指定节点下包含资产, 优先级高于 with_assets
self._with_assets_limit = with_assets_limit
self._node_assets_mapper = defaultdict(dict)
# 是否包含资产总数量为 0 的节点
self._full_tree = full_tree
# 初始化时构建树
self.build()
def _check_category(self, category):
if category is None:
return None
if category in Category.values:
return category
logger.warning(f"Invalid category '{category}' for AssetSearchTree.")
return None
@timeit
def build(self):
self._load_nodes_attr_mapper()
self._load_category_platforms_if_needed()
self._load_nodes_assets_amount()
self._load_nodes_assets_if_needed()
self._init_tree()
self._compute_assets_amount_total()
self._remove_nodes_with_zero_assets_if_needed()
@timeit
def _load_category_platforms_if_needed(self):
if self._category is None:
return
ids = Platform.objects.filter(category=self._category).values_list('id', flat=True)
ids = self._uuids_to_string(ids)
self._category_platform_ids = ids
@timeit
def _load_nodes_attr_mapper(self):
nodes = Node.objects.filter(org_id=self._org.id).values('id', 'key', 'value')
# 保证节点按 key 顺序加载,以便后续构建树时父节点总在子节点前面
nodes = sorted(nodes, key=lambda n: [int(i) for i in n['key'].split(':')])
for node in list(nodes):
node['id'] = str(node['id'])
self._nodes_attr_mapper[node['id']] = node
@timeit
def _load_nodes_assets_amount(self):
q = self._make_assets_q_object()
nodes_amount = Asset.objects.filter(q).values('node_id').annotate(
amount=Count('id')
).values('node_id', 'amount')
for nc in list(nodes_amount):
nid = str(nc['node_id'])
self._nodes_assets_amount_mapper[nid] = nc['amount']
@timeit
def _load_nodes_assets_if_needed(self):
need_load = self._with_assets or self._with_assets_node_id
if not need_load:
return
q = self._make_assets_q_object()
if self._with_assets_node_id:
# 仅指定节点下包含资产,优先级高于 with_assets
q &= Q(node_id=self._with_assets_node_id)
assets = Asset.objects.filter(q).values(
'node_id', 'id', 'platform_id', 'name', 'address', 'is_active', 'comment', 'org_id'
)
# 按照 node_key 排序,尽可能保证前面节点的资产较多
# 限制资产数量
assets = assets.order_by('node__key')[:self._with_assets_limit]
for asset in list(assets):
nid = asset['node_id'] = str(asset['node_id'])
aid = asset['id'] = str(asset['id'])
self._node_assets_mapper[nid][aid] = asset
@timeit
def _make_assets_q_object(self) -> Q:
q = Q(org_id=self._org.id)
if self._category_platform_ids:
q &= Q(platform_id__in=self._category_platform_ids)
if self._assets_q_object:
q &= self._assets_q_object
return q
@timeit
def _init_tree(self):
for nid in self._nodes_attr_mapper.keys():
data = self._get_tree_node_data(nid)
node = self.TreeNode(**data)
self.add_node(node)
def _get_tree_node_data(self, node_id):
attr = self._nodes_attr_mapper[node_id]
assets_amount = self._nodes_assets_amount_mapper.get(node_id, 0)
data = {
'_id': node_id,
'key': attr['key'],
'value': attr['value'],
'assets_amount': assets_amount,
}
assets = self._node_assets_mapper[node_id].values()
if assets:
assets = list(assets)
data.update({ 'assets': assets })
return data
@timeit
def _compute_assets_amount_total(self):
for node in reversed(list(self.nodes.values())):
total = node.assets_amount
for child in node.children:
child: AssetTreeNode
total += child.assets_amount_total
node: AssetTreeNode
node.assets_amount_total = total
@timeit
def _remove_nodes_with_zero_assets_if_needed(self):
if self._full_tree:
return
nodes: list[AssetTreeNode] = list(self.nodes.values())
nodes_to_remove = [
node for node in nodes if not node.is_root and node.assets_amount_total == 0
]
for node in nodes_to_remove:
self.remove_node(node)
def get_assets(self):
assets = []
for node in self.nodes.values():
node: AssetTreeNode
_assets = node.get_assets()
assets.extend(_assets)
return assets
def _uuids_to_string(self, uuids):
return [ str(u) for u in uuids ]
def print(self, count=20, simple=True):
print('org_name: ', getattr(self._org, 'name', 'No-org'))
print(f'asset_category: {self._category}')
super().print(count=count, simple=simple)

164
apps/assets/tree/tree.py Normal file
View File

@@ -0,0 +1,164 @@
from common.utils import get_logger, lazyproperty
__all__ = ['TreeNode', 'Tree']
logger = get_logger(__name__)
class TreeNode(object):
def __init__(self, _id, key, value):
self.id = _id
self.key = key
self.value = value
self.children = []
self.parent = None
@lazyproperty
def parent_key(self):
if self.is_root:
return None
return ':'.join(self.key.split(':')[:-1])
@property
def is_root(self):
return self.key.isdigit()
def add_child(self, child_node: 'TreeNode'):
child_node.parent = self
self.children.append(child_node)
def remove_child(self, child_node: 'TreeNode'):
self.children.remove(child_node)
child_node.parent = None
@property
def is_leaf(self):
return len(self.children) == 0
@lazyproperty
def level(self):
return self.key.count(':') + 1
@property
def children_count(self):
return len(self.children)
def as_dict(self, simple=True):
data = {
'key': self.key,
}
if simple:
return data
data.update({
'id': self.id,
'value': self.value,
'level': self.level,
'children_count': self.children_count,
'is_root': self.is_root,
'is_leaf': self.is_leaf,
})
return data
def print(self, simple=True, is_print_keys=False):
def info_as_string(_info):
return ' | '.join(s.ljust(25) for s in _info)
if is_print_keys:
info_keys = [k for k in self.as_dict(simple=simple).keys()]
info_keys_string = info_as_string(info_keys)
print('-' * len(info_keys_string))
print(info_keys_string)
print('-' * len(info_keys_string))
info_values = [str(v) for v in self.as_dict(simple=simple).values()]
info_values_as_string = info_as_string(info_values)
print(info_values_as_string)
print('-' * len(info_values_as_string))
class Tree(object):
def __init__(self):
self.root = None
# { key -> TreeNode }
self.nodes: dict[TreeNode] = {}
@property
def size(self):
return len(self.nodes)
@property
def is_empty(self):
return self.size == 0
@property
def depth(self):
" 返回树的最大深度以及对应的节点key "
if self.is_empty:
return 0, 0
node = max(self.nodes.values(), key=lambda node: node.level)
node: TreeNode
print(f"max_depth_node_key: {node.key}")
return node.level
@property
def width(self):
" 返回树的最大宽度,以及对应的层级数 "
if self.is_empty:
return 0, 0
node = max(self.nodes.values(), key=lambda node: node.children_count)
node: TreeNode
print(f"max_width_level: {node.level + 1}")
return node.children_count
def add_node(self, node: TreeNode):
if node.is_root:
self.root = node
self.nodes[node.key] = node
return
parent = self.get_node(node.parent_key)
if not parent:
error = f""" Cannot add node {node.key}: parent key {node.parent_key} not found.
Please ensure parent nodes are added before child nodes."""
raise ValueError(error)
parent.add_child(node)
self.nodes[node.key] = node
def get_node(self, key: str) -> TreeNode:
return self.nodes.get(key)
def remove_node(self, node: TreeNode):
if node.is_root:
self.root = None
else:
parent: TreeNode = node.parent
parent.remove_child(node)
self.nodes.pop(node.key, None)
def get_nodes(self):
return list(self.nodes.values())
def get_node_children(self, key, with_self=False):
node = self.get_node(key)
if not node:
return []
nodes = []
if with_self:
nodes.append(node)
nodes.extend(node.children)
return nodes
def print(self, count=10, simple=True):
print('tree_root_key: ', getattr(self.root, 'key', 'No-root'))
print('tree_size: ', self.size)
print('tree_depth: ', self.depth)
print('tree_width: ', self.width)
is_print_key = True
for n in list(self.nodes.values())[:count]:
n: TreeNode
n.print(simple=simple, is_print_keys=is_print_key)
is_print_key = False

View File

@@ -43,7 +43,7 @@ from .serializers import (
OperateLogSerializer, OperateLogActionDetailSerializer,
PasswordChangeLogSerializer, ActivityUnionLogSerializer,
FileSerializer, UserSessionSerializer, JobsAuditSerializer,
ServiceAccessLogSerializer
ServiceAccessLogSerializer, OperateLogFullSerializer
)
from .utils import construct_userlogin_usernames, record_operate_log_and_activity_log
@@ -256,7 +256,9 @@ class OperateLogViewSet(OrgReadonlyModelViewSet):
def get_serializer_class(self):
if self.is_action_detail:
return OperateLogActionDetailSerializer
return super().get_serializer_class()
elif self.request.query_params.get('format'):
return OperateLogFullSerializer
return OperateLogSerializer
def get_queryset(self):
current_org_id = str(current_org.id)

View File

@@ -127,6 +127,21 @@ class OperateLogSerializer(BulkOrgResourceModelSerializer):
return i18n_trans(instance.resource)
class DiffFieldSerializer(serializers.JSONField):
def to_file_representation(self, value):
row = getattr(self, '_row') or {}
attrs = {'diff': value, 'resource_type': row.get('resource_type')}
instance = type('OperateLog', (), attrs)
return OperateLogStore.convert_diff_friendly(instance)
class OperateLogFullSerializer(OperateLogSerializer):
diff = DiffFieldSerializer(label=_("Diff"))
class Meta(OperateLogSerializer.Meta):
fields = OperateLogSerializer.Meta.fields + ['diff']
class PasswordChangeLogSerializer(serializers.ModelSerializer):
class Meta:
model = models.PasswordChangeLog

View File

@@ -16,3 +16,4 @@ from .sso import *
from .temp_token import *
from .token import *
from .face import *
from .access_token import *

View File

@@ -0,0 +1,47 @@
from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from oauth2_provider.models import get_access_token_model
from common.api import JMSModelViewSet
from rbac.permissions import RBACPermission
from ..serializers import AccessTokenSerializer
AccessToken = get_access_token_model()
class AccessTokenViewSet(JMSModelViewSet):
"""
OAuth2 Access Token 管理视图集
用户只能查看和撤销自己的 access token
"""
serializer_class = AccessTokenSerializer
permission_classes = [RBACPermission]
http_method_names = ['get', 'options', 'delete']
rbac_perms = {
'revoke': 'oauth2_provider.delete_accesstoken',
}
def get_queryset(self):
"""只返回当前用户的 access token按创建时间倒序"""
return AccessToken.objects.filter(user=self.request.user).order_by('-created')
@action(methods=['DELETE'], detail=True, url_path='revoke')
def revoke(self, request, *args, **kwargs):
"""
撤销 access token 及其关联的 refresh token
如果 token 不存在或不属于当前用户,返回 404
"""
token = get_object_or_404(
AccessToken.objects.filter(user=request.user),
id=kwargs['pk']
)
# 优先撤销 refresh token会自动撤销关联的 access token
token_to_revoke = token.refresh_token if token.refresh_token else token
token_to_revoke.revoke()
return Response(status=HTTP_204_NO_CONTENT)

View File

@@ -1,6 +1,5 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from django.views import View
from common.utils import get_logger
from users.models import User
@@ -66,11 +65,3 @@ class JMSBaseAuthBackend:
class JMSModelBackend(JMSBaseAuthBackend, ModelBackend):
def user_can_authenticate(self, user):
return True
class BaseAuthCallbackClientView(View):
http_method_names = ['get']
def get(self, request):
from authentication.views.utils import redirect_to_guard_view
return redirect_to_guard_view(query_string='next=client')

View File

@@ -1,51 +1,22 @@
# -*- coding: utf-8 -*-
#
import threading
from django.conf import settings
from django.contrib.auth import get_user_model
from django_cas_ng.backends import CASBackend as _CASBackend
from common.utils import get_logger
from ..base import JMSBaseAuthBackend
__all__ = ['CASBackend', 'CASUserDoesNotExist']
__all__ = ['CASBackend']
logger = get_logger(__name__)
class CASUserDoesNotExist(Exception):
"""Exception raised when a CAS user does not exist."""
pass
class CASBackend(JMSBaseAuthBackend, _CASBackend):
@staticmethod
def is_enabled():
return settings.AUTH_CAS
def authenticate(self, request, ticket, service):
UserModel = get_user_model()
manager = UserModel._default_manager
original_get_by_natural_key = manager.get_by_natural_key
thread_local = threading.local()
thread_local.thread_id = threading.get_ident()
logger.debug(f"CASBackend.authenticate: thread_id={thread_local.thread_id}")
def get_by_natural_key(self, username):
logger.debug(f"CASBackend.get_by_natural_key: thread_id={threading.get_ident()}, username={username}")
if threading.get_ident() != thread_local.thread_id:
return original_get_by_natural_key(username)
try:
user = original_get_by_natural_key(username)
except UserModel.DoesNotExist:
raise CASUserDoesNotExist(username)
return user
try:
manager.get_by_natural_key = get_by_natural_key.__get__(manager, type(manager))
user = super().authenticate(request, ticket=ticket, service=service)
finally:
manager.get_by_natural_key = original_get_by_natural_key
return user
# 这里做个hack ,让父类始终走CAS_CREATE_USER=True的逻辑然后调用 authentication/mixins.py 中的 custom_get_or_create 方法
settings.CAS_CREATE_USER = True
return super().authenticate(request, ticket, service)

View File

@@ -3,11 +3,10 @@
import django_cas_ng.views
from django.urls import path
from .views import CASLoginView, CASCallbackClientView
from .views import CASLoginView
urlpatterns = [
path('login/', CASLoginView.as_view(), name='cas-login'),
path('logout/', django_cas_ng.views.LogoutView.as_view(), name='cas-logout'),
path('callback/', django_cas_ng.views.CallbackView.as_view(), name='cas-proxy-callback'),
path('login/client', CASCallbackClientView.as_view(), name='cas-proxy-callback-client'),
path('callback/', django_cas_ng.views.CallbackView.as_view(), name='cas-proxy-callback')
]

View File

@@ -3,31 +3,20 @@ from django.http import HttpResponseRedirect
from django.utils.translation import gettext_lazy as _
from django_cas_ng.views import LoginView
from authentication.backends.base import BaseAuthCallbackClientView
from common.utils import FlashMessageUtil
from .backends import CASUserDoesNotExist
from authentication.views.mixins import FlashMessageMixin
__all__ = ['LoginView']
class CASLoginView(LoginView):
class CASLoginView(LoginView, FlashMessageMixin):
def get(self, request):
try:
resp = super().get(request)
return resp
except PermissionDenied:
return HttpResponseRedirect('/')
except CASUserDoesNotExist as e:
message_data = {
'title': _('User does not exist: {}').format(e),
'error': _(
'CAS login was successful, but no corresponding local user was found in the system, and automatic '
'user creation is disabled in the CAS authentication configuration. Login failed.'),
'interval': 10,
'redirect_url': '/',
}
return FlashMessageUtil.gen_and_redirect_to(message_data)
class CASCallbackClientView(BaseAuthCallbackClientView):
pass
resp = HttpResponseRedirect('/')
error_message = getattr(request, 'error_message', '')
if error_message:
response = self.get_failed_response('/', title=_('CAS Error'), msg=error_message)
return response
else:
return resp

View File

@@ -69,6 +69,8 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
msg = _('Invalid token header. Sign string should not contain invalid characters.')
raise exceptions.AuthenticationFailed(msg)
user, header = self.authenticate_credentials(token)
if not user:
return None
after_authenticate_update_date(user)
return user, header
@@ -77,10 +79,6 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
model = get_user_model()
user_id = cache.get(token)
user = get_object_or_none(model, id=user_id)
if not user:
msg = _('Invalid token or cache refreshed.')
raise exceptions.AuthenticationFailed(msg)
return user, None
def authenticate_header(self, request):

View File

@@ -7,6 +7,5 @@ from . import views
urlpatterns = [
path('login/', views.OAuth2AuthRequestView.as_view(), name='login'),
path('callback/', views.OAuth2AuthCallbackView.as_view(), name='login-callback'),
path('callback/client/', views.OAuth2AuthCallbackClientView.as_view(), name='login-callback-client'),
path('logout/', views.OAuth2EndSessionView.as_view(), name='logout')
]

View File

@@ -6,27 +6,34 @@ from django.utils.http import urlencode
from django.utils.translation import gettext_lazy as _
from django.views import View
from authentication.backends.base import BaseAuthCallbackClientView
from authentication.decorators import pre_save_next_to_session, redirect_to_pre_save_next_after_auth
from authentication.mixins import authenticate
from authentication.utils import build_absolute_uri
from authentication.views.mixins import FlashMessageMixin
from common.utils import get_logger
from common.utils import get_logger, safe_next_url
logger = get_logger(__file__)
class OAuth2AuthRequestView(View):
@pre_save_next_to_session()
def get(self, request):
log_prompt = "Process OAuth2 GET requests: {}"
logger.debug(log_prompt.format('Start'))
request_params = request.GET.dict()
request_params.pop('next', None)
query = urlencode(request_params)
redirect_uri = build_absolute_uri(
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
)
redirect_uri = f"{redirect_uri}?{query}"
query_dict = {
'client_id': settings.AUTH_OAUTH2_CLIENT_ID, 'response_type': 'code',
'scope': settings.AUTH_OAUTH2_SCOPE,
'redirect_uri': build_absolute_uri(
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
)
'redirect_uri': redirect_uri
}
if '?' in settings.AUTH_OAUTH2_PROVIDER_AUTHORIZATION_ENDPOINT:
@@ -45,6 +52,7 @@ class OAuth2AuthRequestView(View):
class OAuth2AuthCallbackView(View, FlashMessageMixin):
http_method_names = ['get', ]
@redirect_to_pre_save_next_after_auth
def get(self, request):
""" Processes GET requests. """
log_prompt = "Process GET requests [OAuth2AuthCallbackView]: {}"
@@ -59,9 +67,7 @@ class OAuth2AuthCallbackView(View, FlashMessageMixin):
logger.debug(log_prompt.format('Login: {}'.format(user)))
auth.login(self.request, user)
logger.debug(log_prompt.format('Redirect'))
return HttpResponseRedirect(
settings.AUTH_OAUTH2_AUTHENTICATION_REDIRECT_URI
)
return HttpResponseRedirect(settings.AUTH_OAUTH2_AUTHENTICATION_REDIRECT_URI)
else:
if getattr(request, 'error_message', ''):
response = self.get_failed_response('/', title=_('OAuth2 Error'), msg=request.error_message)
@@ -72,10 +78,6 @@ class OAuth2AuthCallbackView(View, FlashMessageMixin):
return HttpResponseRedirect(redirect_url)
class OAuth2AuthCallbackClientView(BaseAuthCallbackClientView):
pass
class OAuth2EndSessionView(View):
http_method_names = ['get', 'post', ]

View File

@@ -0,0 +1,20 @@
from django.db.models.signals import post_delete
from django.dispatch import receiver
from django.core.cache import cache
from django.conf import settings
from oauth2_provider.models import get_application_model
from .utils import clear_oauth2_authorization_server_view_cache
__all__ = ['on_oauth2_provider_application_deleted']
Application = get_application_model()
@receiver(post_delete, sender=Application)
def on_oauth2_provider_application_deleted(sender, instance, **kwargs):
if instance.name == settings.OAUTH2_PROVIDER_JUMPSERVER_CLIENT_NAME:
clear_oauth2_authorization_server_view_cache()

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
#
from django.urls import path
from oauth2_provider import views as op_views
from . import views
urlpatterns = [
path("authorize/", op_views.AuthorizationView.as_view(), name="authorize"),
path("token/", op_views.TokenView.as_view(), name="token"),
path("revoke/", op_views.RevokeTokenView.as_view(), name="revoke-token"),
path(".well-known/oauth-authorization-server", views.OAuthAuthorizationServerView.as_view(), name="oauth-authorization-server"),
]

View File

@@ -0,0 +1,31 @@
from django.conf import settings
from django.core.cache import cache
from oauth2_provider.models import get_application_model
from common.utils import get_logger
logger = get_logger(__name__)
def get_or_create_jumpserver_client_application():
"""Auto get or create OAuth2 JumpServer Client application."""
Application = get_application_model()
application, created = Application.objects.get_or_create(
name=settings.OAUTH2_PROVIDER_JUMPSERVER_CLIENT_NAME,
defaults={
'client_type': Application.CLIENT_PUBLIC,
'authorization_grant_type': Application.GRANT_AUTHORIZATION_CODE,
'redirect_uris': settings.OAUTH2_PROVIDER_CLIENT_REDIRECT_URI,
'skip_authorization': True,
}
)
return application
CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX = 'oauth2_provider_metadata'
def clear_oauth2_authorization_server_view_cache():
logger.info("Clearing OAuth2 Authorization Server Metadata view cache")
cache_key = f'views.decorators.cache.cache_page.{CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX}.GET*'
cache.delete_pattern(cache_key)

View File

@@ -0,0 +1,77 @@
from django.views.generic import View
from django.http import JsonResponse
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from django.views.decorators.csrf import csrf_exempt
from django.conf import settings
from django.urls import reverse
from oauth2_provider.settings import oauth2_settings
from typing import List, Dict, Any
from .utils import get_or_create_jumpserver_client_application, CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX
@method_decorator(csrf_exempt, name='dispatch')
@method_decorator(cache_page(timeout=60 * 60 * 24, key_prefix=CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX), name='dispatch')
class OAuthAuthorizationServerView(View):
"""
OAuth 2.0 Authorization Server Metadata Endpoint
RFC 8414: https://datatracker.ietf.org/doc/html/rfc8414
This endpoint provides machine-readable information about the
OAuth 2.0 authorization server's configuration.
"""
def get_base_url(self, request) -> str:
scheme = 'https' if request.is_secure() else 'http'
host = request.get_host()
return f"{scheme}://{host}"
def get_supported_scopes(self) -> List[str]:
scopes_config = oauth2_settings.SCOPES
if isinstance(scopes_config, dict):
return list(scopes_config.keys())
return []
def get_metadata(self, request) -> Dict[str, Any]:
base_url = self.get_base_url(request)
application = get_or_create_jumpserver_client_application()
metadata = {
"issuer": base_url,
"client_id": application.client_id if application else "Not found any application.",
"authorization_endpoint": base_url + reverse('authentication:oauth2-provider:authorize'),
"token_endpoint": base_url + reverse('authentication:oauth2-provider:token'),
"revocation_endpoint": base_url + reverse('authentication:oauth2-provider:revoke-token'),
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"scopes_supported": self.get_supported_scopes(),
"token_endpoint_auth_methods_supported": ["none"],
"revocation_endpoint_auth_methods_supported": ["none"],
"code_challenge_methods_supported": ["S256"],
"response_modes_supported": ["query"],
}
if hasattr(oauth2_settings, 'ACCESS_TOKEN_EXPIRE_SECONDS'):
metadata["token_expires_in"] = oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS
if hasattr(oauth2_settings, 'REFRESH_TOKEN_EXPIRE_SECONDS'):
if oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS:
metadata["refresh_token_expires_in"] = oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS
return metadata
def get(self, request, *args, **kwargs):
metadata = self.get_metadata(request)
response = JsonResponse(metadata)
self.add_cors_headers(response)
return response
def options(self, request, *args, **kwargs):
response = JsonResponse({})
self.add_cors_headers(response)
return response
@staticmethod
def add_cors_headers(response):
response['Access-Control-Allow-Origin'] = '*'
response['Access-Control-Allow-Methods'] = 'GET, OPTIONS'
response['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response['Access-Control-Max-Age'] = '3600'

View File

@@ -15,6 +15,5 @@ from . import views
urlpatterns = [
path('login/', views.OIDCAuthRequestView.as_view(), name='login'),
path('callback/', views.OIDCAuthCallbackView.as_view(), name='login-callback'),
path('callback/client/', views.OIDCAuthCallbackClientView.as_view(), name='login-callback-client'),
path('logout/', views.OIDCEndSessionView.as_view(), name='logout'),
]

View File

@@ -25,11 +25,11 @@ from django.utils.http import urlencode
from django.utils.translation import gettext_lazy as _
from django.views.generic import View
from authentication.decorators import pre_save_next_to_session, redirect_to_pre_save_next_after_auth
from authentication.utils import build_absolute_uri_for_oidc
from authentication.views.mixins import FlashMessageMixin
from common.utils import safe_next_url
from .utils import get_logger
from ..base import BaseAuthCallbackClientView
logger = get_logger(__file__)
@@ -58,6 +58,7 @@ class OIDCAuthRequestView(View):
b = base64.urlsafe_b64encode(h)
return b.decode('ascii')[:-1]
@pre_save_next_to_session()
def get(self, request):
""" Processes GET requests. """
@@ -66,8 +67,9 @@ class OIDCAuthRequestView(View):
# Defines common parameters used to bootstrap the authentication request.
logger.debug(log_prompt.format('Construct request params'))
authentication_request_params = request.GET.dict()
authentication_request_params.update({
request_params = request.GET.dict()
request_params.pop('next', None)
request_params.update({
'scope': settings.AUTH_OPENID_SCOPES,
'response_type': 'code',
'client_id': settings.AUTH_OPENID_CLIENT_ID,
@@ -80,7 +82,7 @@ class OIDCAuthRequestView(View):
code_verifier = self.gen_code_verifier()
code_challenge_method = settings.AUTH_OPENID_CODE_CHALLENGE_METHOD or 'S256'
code_challenge = self.gen_code_challenge(code_verifier, code_challenge_method)
authentication_request_params.update({
request_params.update({
'code_challenge_method': code_challenge_method,
'code_challenge': code_challenge
})
@@ -91,7 +93,7 @@ class OIDCAuthRequestView(View):
if settings.AUTH_OPENID_USE_STATE:
logger.debug(log_prompt.format('Use state'))
state = get_random_string(settings.AUTH_OPENID_STATE_LENGTH)
authentication_request_params.update({'state': state})
request_params.update({'state': state})
request.session['oidc_auth_state'] = state
# Nonces should be used too! In that case the generated nonce is stored both in the
@@ -99,17 +101,12 @@ class OIDCAuthRequestView(View):
if settings.AUTH_OPENID_USE_NONCE:
logger.debug(log_prompt.format('Use nonce'))
nonce = get_random_string(settings.AUTH_OPENID_NONCE_LENGTH)
authentication_request_params.update({'nonce': nonce, })
request_params.update({'nonce': nonce, })
request.session['oidc_auth_nonce'] = nonce
# Stores the "next" URL in the session if applicable.
logger.debug(log_prompt.format('Stores next url in the session'))
next_url = request.GET.get('next')
request.session['oidc_auth_next_url'] = safe_next_url(next_url, request=request)
# Redirects the user to authorization endpoint.
logger.debug(log_prompt.format('Construct redirect url'))
query = urlencode(authentication_request_params)
query = urlencode(request_params)
redirect_url = '{url}?{query}'.format(
url=settings.AUTH_OPENID_PROVIDER_AUTHORIZATION_ENDPOINT, query=query)
@@ -129,6 +126,8 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
http_method_names = ['get', ]
@redirect_to_pre_save_next_after_auth
def get(self, request):
""" Processes GET requests. """
log_prompt = "Process GET requests [OIDCAuthCallbackView]: {}"
@@ -167,7 +166,6 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
raise SuspiciousOperation('Invalid OpenID Connect callback state value')
# Authenticates the end-user.
next_url = request.session.get('oidc_auth_next_url', None)
code_verifier = request.session.get('oidc_auth_code_verifier', None)
logger.debug(log_prompt.format('Process authenticate'))
try:
@@ -191,9 +189,7 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
callback_params.get('session_state', None)
logger.debug(log_prompt.format('Redirect'))
return HttpResponseRedirect(
next_url or settings.AUTH_OPENID_AUTHENTICATION_REDIRECT_URI
)
return HttpResponseRedirect(settings.AUTH_OPENID_AUTHENTICATION_REDIRECT_URI)
if 'error' in callback_params:
logger.debug(
log_prompt.format('Error in callback params: {}'.format(callback_params['error']))
@@ -212,10 +208,6 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
return HttpResponseRedirect(redirect_url)
class OIDCAuthCallbackClientView(BaseAuthCallbackClientView):
pass
class OIDCEndSessionView(View):
""" Allows to end the session of any user authenticated using OpenID Connect.

View File

@@ -8,6 +8,5 @@ urlpatterns = [
path('login/', views.Saml2AuthRequestView.as_view(), name='saml2-login'),
path('logout/', views.Saml2EndSessionView.as_view(), name='saml2-logout'),
path('callback/', views.Saml2AuthCallbackView.as_view(), name='saml2-callback'),
path('callback/client/', views.Saml2AuthCallbackClientView.as_view(), name='saml2-callback-client'),
path('metadata/', views.Saml2AuthMetadataView.as_view(), name='saml2-metadata'),
]

View File

@@ -17,9 +17,8 @@ from onelogin.saml2.idp_metadata_parser import (
)
from authentication.views.mixins import FlashMessageMixin
from common.utils import get_logger
from common.utils import get_logger, safe_next_url
from .settings import JmsSaml2Settings
from ..base import BaseAuthCallbackClientView
logger = get_logger(__file__)
@@ -208,13 +207,16 @@ class Saml2AuthRequestView(View, PrepareRequestMixin):
log_prompt = "Process SAML GET requests: {}"
logger.debug(log_prompt.format('Start'))
request_params = request.GET.dict()
try:
saml_instance = self.init_saml_auth(request)
except OneLogin_Saml2_Error as error:
logger.error(log_prompt.format('Init saml auth error: %s' % error))
return HttpResponse(error, status=412)
next_url = settings.AUTH_SAML2_PROVIDER_AUTHORIZATION_ENDPOINT
next_url = request_params.get('next') or settings.AUTH_SAML2_PROVIDER_AUTHORIZATION_ENDPOINT
next_url = safe_next_url(next_url, request=request)
url = saml_instance.login(return_to=next_url)
logger.debug(log_prompt.format('Redirect login url'))
return HttpResponseRedirect(url)
@@ -293,10 +295,11 @@ class Saml2AuthCallbackView(View, PrepareRequestMixin, FlashMessageMixin):
return response
logger.debug(log_prompt.format('Redirect'))
redir = post_data.get('RelayState')
if not redir or len(redir) == 0:
redir = "/"
next_url = saml_instance.redirect_to(redir)
relay_state = post_data.get('RelayState')
if not relay_state or len(relay_state) == 0:
relay_state = "/"
next_url = saml_instance.redirect_to(relay_state)
next_url = safe_next_url(next_url, request=request)
return HttpResponseRedirect(next_url)
@csrf_exempt
@@ -304,10 +307,6 @@ class Saml2AuthCallbackView(View, PrepareRequestMixin, FlashMessageMixin):
return super().dispatch(*args, **kwargs)
class Saml2AuthCallbackClientView(BaseAuthCallbackClientView):
pass
class Saml2AuthMetadataView(View, PrepareRequestMixin):
def get(self, request):

View File

@@ -1,6 +1,9 @@
from django.db.models import TextChoices
from django.utils.translation import gettext_lazy as _
USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD = 'next'
RSA_PRIVATE_KEY = 'rsa_private_key'
RSA_PUBLIC_KEY = 'rsa_public_key'

View File

@@ -0,0 +1,193 @@
"""
This module provides decorators to handle redirect URLs during the authentication flow:
1. pre_save_next_to_session: Captures and stores the intended next URL before redirecting to auth provider
2. redirect_to_pre_save_next_after_auth: Redirects to the stored next URL after successful authentication
3. post_save_next_to_session: Copies the stored next URL to session['next'] after view execution
"""
from urllib.parse import urlparse
from django.http import HttpResponseRedirect
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from functools import wraps
from common.utils import get_logger, safe_next_url
from .const import USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD
logger = get_logger(__file__)
__all__ = [
'pre_save_next_to_session', 'redirect_to_pre_save_next_after_auth',
'post_save_next_to_session_if_guard_redirect'
]
# Session key for storing the redirect URL after authentication
AUTH_SESSION_NEXT_URL_KEY = 'auth_next_url'
def pre_save_next_to_session(get_next_url=None):
"""
Decorator to capture and store the 'next' parameter into session BEFORE view execution.
This decorator is applied to the authentication request view to preserve the user's
intended destination URL before redirecting to the authentication provider.
Args:
get_next_url: Optional callable that extracts the next URL from request.
Default: lambda req: req.GET.get('next')
Usage:
# Use default (request.GET.get('next'))
@pre_save_next_to_session()
def get(self, request):
pass
# Custom extraction from POST data
@pre_save_next_to_session(get_next_url=lambda req: req.POST.get('next'))
def post(self, request):
pass
# Custom extraction from both GET and POST
@pre_save_next_to_session(
get_next_url=lambda req: req.GET.get('next') or req.POST.get('next')
)
def get(self, request):
pass
Example flow:
User accesses: /auth/oauth2/?next=/dashboard/
↓ (decorator saves '/dashboard/' to session)
Redirected to OAuth2 provider for authentication
"""
# Default function to extract next URL from request.GET
if get_next_url is None:
get_next_url = lambda req: req.GET.get('next')
def decorator(view_func):
@wraps(view_func)
def wrapper(self, request, *args, **kwargs):
next_url = get_next_url(request)
if next_url:
request.session[AUTH_SESSION_NEXT_URL_KEY] = next_url
logger.debug(f"[Auth] Saved next_url to session: {next_url}")
return view_func(self, request, *args, **kwargs)
return wrapper
return decorator
def redirect_to_pre_save_next_after_auth(view_func):
"""
Decorator to redirect to the previously saved 'next' URL after successful authentication.
This decorator is applied to the authentication callback view. After the user successfully
authenticates, if a 'next' URL was previously saved in the session (by pre_save_next_to_session),
the user will be redirected to that URL instead of the default redirect location.
Conditions for redirect:
- User must be authenticated (request.user.is_authenticated)
- Session must contain the saved next URL (AUTH_SESSION_NEXT_URL_KEY)
- The next URL must not be '/' (avoid unnecessary redirects)
- The next URL must pass security validation (safe_next_url)
If any condition fails, returns the original view response.
Usage:
@redirect_to_pre_save_next_after_auth
def get(self, request):
# Process authentication callback
if user_authenticated:
auth.login(request, user)
return HttpResponseRedirect(default_url)
Example flow:
User redirected back from OAuth2 provider: /auth/oauth2/callback/?code=xxx
↓ (view processes authentication, user becomes authenticated)
Decorator checks session for saved next URL
↓ (finds '/dashboard/' in session)
Redirects to: /dashboard/
↓ (clears saved URL from session)
"""
@wraps(view_func)
def wrapper(self, request, *args, **kwargs):
# Execute the original view method first
response = view_func(self, request, *args, **kwargs)
# Check if user has been authenticated
if request.user and request.user.is_authenticated:
# Check if session contains a saved next URL
saved_next_url = request.session.get(AUTH_SESSION_NEXT_URL_KEY)
if saved_next_url and saved_next_url != '/':
# Validate the URL for security
safe_url = safe_next_url(saved_next_url, request=request)
if safe_url:
# Clear the saved URL from session (one-time use)
request.session.pop(AUTH_SESSION_NEXT_URL_KEY, None)
logger.debug(f"[Auth] Redirecting authenticated user to saved next_url: {safe_url}")
return HttpResponseRedirect(safe_url)
# Return the original response if no redirect conditions are met
return response
return wrapper
def post_save_next_to_session_if_guard_redirect(view_func):
"""
Decorator to copy AUTH_SESSION_NEXT_URL_KEY to session['next'] after view execution,
but only if redirecting to login-guard view.
This decorator is applied AFTER view execution. It copies the value from
AUTH_SESSION_NEXT_URL_KEY (internal storage) to 'next' (standard session key)
for use by downstream code.
Only sets the 'next' session key when the response is a redirect to guard-view
(i.e., response with redirect status code and location path matching login-guard view URL).
Usage:
@post_save_next_to_session_if_guard_redirect
def get(self, request):
# Process the request and return response
if some_condition:
return self.redirect_to_guard_view() # Decorator will copy next to session
return HttpResponseRedirect(url) # Decorator won't copy if not to guard-view
Example flow:
View executes and returns redirect to guard view
↓ (response is redirect with 'login-guard' in Location)
Decorator checks if response is redirect to guard-view and session has saved next URL
↓ (copies AUTH_SESSION_NEXT_URL_KEY to session['next'])
User is redirected to guard-view with 'next' available in session
"""
@wraps(view_func)
def wrapper(self, request, *args, **kwargs):
# Execute the original view method
response = view_func(self, request, *args, **kwargs)
# Check if response is a redirect to guard view
# Redirect responses typically have status codes 301, 302, 303, 307, 308
is_guard_redirect = False
if hasattr(response, 'status_code') and response.status_code in (301, 302, 303, 307, 308):
# Check if the redirect location is to guard view
location = response.get('Location', '')
if location:
# Extract path from location URL (handle both absolute and relative URLs)
parsed_url = urlparse(location)
path = parsed_url.path
# Check if path matches guard view URL pattern
guard_view_url = reverse('authentication:login-guard')
if path == guard_view_url:
is_guard_redirect = True
# Only set 'next' if response is a redirect to guard view
if is_guard_redirect:
# Copy AUTH_SESSION_NEXT_URL_KEY to 'next' if it exists
saved_next_url = request.session.get(AUTH_SESSION_NEXT_URL_KEY)
if saved_next_url:
# 这里 'next' 是 UserLoginGuardView.redirect_field_name
request.session[USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD] = saved_next_url
logger.debug(f"[Auth] Copied {AUTH_SESSION_NEXT_URL_KEY} to 'next' in session: {saved_next_url}")
return response
return wrapper

View File

@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
#
from django.core.management.base import BaseCommand
from django.db.utils import OperationalError, ProgrammingError
from django.conf import settings
class Command(BaseCommand):
help = 'Initialize OAuth2 Provider - Create default JumpServer Client application'
def add_arguments(self, parser):
parser.add_argument(
'--force',
action='store_true',
help='Force recreate the application even if it exists',
)
def handle(self, *args, **options):
force = options.get('force', False)
try:
from authentication.backends.oauth2_provider.utils import (
get_or_create_jumpserver_client_application
)
from oauth2_provider.models import get_application_model
Application = get_application_model()
# 检查表是否存在
try:
Application.objects.exists()
except (OperationalError, ProgrammingError) as e:
self.stdout.write(
self.style.ERROR(
f'OAuth2 Provider tables not found. Please run migrations first:\n'
f' python manage.py migrate oauth2_provider\n'
f'Error: {e}'
)
)
return
# 如果强制重建,先删除已存在的应用
if force:
deleted_count, _ = Application.objects.filter(
name=settings.OAUTH2_PROVIDER_JUMPSERVER_CLIENT_NAME
).delete()
if deleted_count > 0:
self.stdout.write(
self.style.WARNING(f'Deleted {deleted_count} existing application(s)')
)
# 创建或获取应用
application = get_or_create_jumpserver_client_application()
if application:
self.stdout.write(
self.style.SUCCESS(
f'✓ OAuth2 JumpServer Client application initialized successfully\n'
f' - Client ID: {application.client_id}\n'
f' - Client Type: {application.get_client_type_display()}\n'
f' - Grant Type: {application.get_authorization_grant_type_display()}\n'
f' - Redirect URIs: {application.redirect_uris}\n'
f' - Skip Authorization: {application.skip_authorization}'
)
)
else:
self.stdout.write(
self.style.ERROR('Failed to create OAuth2 application')
)
except Exception as e:
self.stdout.write(
self.style.ERROR(f'Error initializing OAuth2 Provider: {e}')
)
raise

View File

@@ -6,6 +6,7 @@ import time
import uuid
from functools import partial
from typing import Callable
from werkzeug.local import Local
from django.conf import settings
from django.contrib import auth
@@ -16,6 +17,7 @@ from django.contrib.auth import (
from django.contrib.auth import get_user_model
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured
from django.db.models import Q
from django.shortcuts import reverse, redirect, get_object_or_404
from django.utils.http import urlencode
from django.utils.translation import gettext as _
@@ -31,6 +33,87 @@ from .signals import post_auth_success, post_auth_failed
logger = get_logger(__name__)
# 模块级别的线程上下文,用于 authenticate 函数中标记当前线程
_auth_thread_context = Local()
# 保存 Django 原始的 get_or_create 方法(在模块加载时保存一次)
def _save_original_get_or_create():
"""保存 Django 原始的 get_or_create 方法"""
from django.contrib.auth import get_user_model as get_user_model_func
UserModel = get_user_model_func()
return UserModel.objects.get_or_create
_django_original_get_or_create = _save_original_get_or_create()
class OnlyAllowExistUserAuthError(Exception):
pass
def _authenticate_context(func):
"""
装饰器:管理 authenticate 函数的执行上下文
功能:
1. 执行前:
- 在线程本地存储中标记当前正在执行 authenticate
- 临时替换 UserModel.objects.get_or_create 方法
2. 执行后:
- 清理线程本地存储标记
- 恢复 get_or_create 为 Django 原始方法
作用:
- 确保 get_or_create 行为仅在 authenticate 生命周期内生效
- 支持 ONLY_ALLOW_EXIST_USER_AUTH 配置的线程安全实现
- 防止跨请求或跨线程的状态污染
"""
from functools import wraps
@wraps(func)
def wrapper(request=None, **credentials):
from django.contrib.auth import get_user_model
UserModel = get_user_model()
def custom_get_or_create(*args, **kwargs):
create_username = kwargs.get('username')
logger.debug(f"get_or_create: thread_id={threading.get_ident()}, username={create_username}")
# 如果当前线程正在执行 authenticate 且仅允许已存在用户认证,则提前判断用户是否存在
if (
getattr(_auth_thread_context, 'in_authenticate', False) and
settings.ONLY_ALLOW_EXIST_USER_AUTH
):
try:
UserModel.objects.get(username=create_username)
except UserModel.DoesNotExist:
raise OnlyAllowExistUserAuthError
# 调用 Django 原始方法(已是绑定方法,直接传参)
return _django_original_get_or_create(*args, **kwargs)
try:
# 执行前:设置线程上下文和 monkey-patch
setattr(_auth_thread_context, 'in_authenticate', True)
UserModel.objects.get_or_create = custom_get_or_create
# 执行原函数
return func(request, **credentials)
finally:
# 执行后:清理线程上下文和恢复原始方法
try:
if hasattr(_auth_thread_context, 'in_authenticate'):
delattr(_auth_thread_context, 'in_authenticate')
except Exception:
pass
try:
UserModel.objects.get_or_create = _django_original_get_or_create
except Exception:
pass
return wrapper
def _get_backends(return_tuples=False):
backends = []
@@ -48,39 +131,16 @@ def _get_backends(return_tuples=False):
return backends
class OnlyAllowExistUserAuthError(Exception):
pass
auth._get_backends = _get_backends
@_authenticate_context
def authenticate(request=None, **credentials):
"""
If the given credentials are valid, return a User object.
之所以 hack 这个 authenticate
"""
UserModel = get_user_model()
original_get_or_create = UserModel.objects.get_or_create
thread_local = threading.local()
thread_local.thread_id = threading.get_ident()
def custom_get_or_create(self, *args, **kwargs):
logger.debug(f"get_or_create: thread_id={threading.get_ident()}, username={username}")
if threading.get_ident() != thread_local.thread_id or not settings.ONLY_ALLOW_EXIST_USER_AUTH:
return original_get_or_create(*args, **kwargs)
create_username = kwargs.get('username')
try:
UserModel.objects.get(username=create_username)
except UserModel.DoesNotExist:
raise OnlyAllowExistUserAuthError
return original_get_or_create(*args, **kwargs)
username = credentials.get('username')
temp_user = None
username = credentials.get('username')
for backend, backend_path in _get_backends(return_tuples=True):
# 检查用户名是否允许认证 (预先检查,不浪费认证时间)
logger.info('Try using auth backend: {}'.format(str(backend)))
@@ -94,27 +154,28 @@ def authenticate(request=None, **credentials):
except TypeError:
# This backend doesn't accept these credentials as arguments. Try the next one.
continue
try:
UserModel.objects.get_or_create = custom_get_or_create.__get__(UserModel.objects)
user = backend.authenticate(request, **credentials)
except PermissionDenied:
# This backend says to stop in our tracks - this user should not be allowed in at all.
break
except OnlyAllowExistUserAuthError:
request.error_message = _(
'''The administrator has enabled "Only allow existing users to log in",
and the current user is not in the user list. Please contact the administrator.'''
)
if request:
request.error_message = _(
'''The administrator has enabled "Only allow existing users to log in",
and the current user is not in the user list. Please contact the administrator.'''
)
continue
finally:
UserModel.objects.get_or_create = original_get_or_create
if user is None:
continue
if not user.is_valid:
temp_user = user
temp_user.backend = backend_path
request.error_message = _('User is invalid')
if request:
request.error_message = _('User is invalid')
return temp_user
# 检查用户是否允许认证
@@ -129,8 +190,11 @@ def authenticate(request=None, **credentials):
else:
if temp_user is not None:
source_display = temp_user.source_display
request.error_message = _('''The administrator has enabled 'Only allow login from user source'.
The current user source is {}. Please contact the administrator.''').format(source_display)
if request:
request.error_message = _(
''' The administrator has enabled 'Only allow login from user source'.
The current user source is {}. Please contact the administrator. '''
).format(source_display)
return temp_user
# The credentials supplied are invalid to all backends, fire signal
@@ -228,7 +292,8 @@ class AuthPreCheckMixin:
if not settings.ONLY_ALLOW_EXIST_USER_AUTH:
return
exist = User.objects.filter(username=username).exists()
q = Q(username=username) | Q(email=username)
exist = User.objects.filter(q).exists()
if not exist:
logger.error(f"Only allow exist user auth, login failed: {username}")
self.raise_credential_error(errors.reason_user_not_exist)

View File

@@ -9,11 +9,12 @@ from common.utils import get_object_or_none, random_string
from users.models import User
from users.serializers import UserProfileSerializer
from ..models import AccessKey, TempToken
from oauth2_provider.models import get_access_token_model
__all__ = [
'AccessKeySerializer', 'BearerTokenSerializer',
'SSOTokenSerializer', 'TempTokenSerializer',
'AccessKeyCreateSerializer'
'AccessKeyCreateSerializer', 'AccessTokenSerializer',
]
@@ -114,3 +115,28 @@ class TempTokenSerializer(serializers.ModelSerializer):
token = TempToken(**kwargs)
token.save()
return token
class AccessTokenSerializer(serializers.ModelSerializer):
token_preview = serializers.SerializerMethodField(label=_("Token"))
class Meta:
model = get_access_token_model()
fields = [
'id', 'user', 'token_preview', 'is_valid',
'is_expired', 'expires', 'scope', 'created', 'updated',
]
read_only_fields = fields
extra_kwargs = {
'scope': { 'label': _('Scope') },
'expires': { 'label': _('Date expired') },
'updated': { 'label': _('Date updated') },
'created': { 'label': _('Date created') },
}
def get_token_preview(self, obj):
token_string = obj.token
if len(token_string) > 16:
return f"{token_string[:6]}...{token_string[-4:]}"
return "****"

View File

@@ -9,6 +9,8 @@ from audits.models import UserSession
from common.sessions.cache import user_session_manager
from .signals import post_auth_success, post_auth_failed, user_auth_failed, user_auth_success
from .backends.oauth2_provider.signal_handlers import *
@receiver(user_logged_in)
def on_user_auth_login_success(sender, user, request, **kwargs):
@@ -57,3 +59,4 @@ def on_user_login_success(sender, request, user, backend, create=False, **kwargs
def on_user_login_failed(sender, username, request, reason, backend, **kwargs):
request.session['auth_backend'] = backend
post_auth_failed.send(sender, username=username, request=request, reason=reason)

View File

@@ -47,3 +47,9 @@ def clean_expire_token():
count = TempToken.objects.filter(date_expired__lt=expired_time).delete()
logging.info('Deleted %d temporary tokens.', count[0])
logging.info('Cleaned expired temporary and connection tokens.')
@register_as_period_task(crontab=CRONTAB_AT_AM_TWO)
def clear_oauth2_provider_expired_tokens():
from oauth2_provider.models import clear_expired
clear_expired()

View File

@@ -16,6 +16,7 @@ router.register('super-connection-token', api.SuperConnectionTokenViewSet, 'supe
router.register('admin-connection-token', api.AdminConnectionTokenViewSet, 'admin-connection-token')
router.register('confirm', api.UserConfirmationViewSet, 'confirm')
router.register('ssh-key', api.SSHkeyViewSet, 'ssh-key')
router.register('access-tokens', api.AccessTokenViewSet, 'access-token')
urlpatterns = [
path('<str:backend>/qr/unbind/', api.QRUnBindForUserApi.as_view(), name='qr-unbind'),

View File

@@ -83,4 +83,6 @@ urlpatterns = [
path('oauth2/', include(('authentication.backends.oauth2.urls', 'authentication'), namespace='oauth2')),
path('captcha/', include('captcha.urls')),
path('oauth2-provider/', include(('authentication.backends.oauth2_provider.urls', 'authentication'), namespace='oauth2-provider'))
]

View File

@@ -11,6 +11,7 @@ from rest_framework.request import Request
from authentication import errors
from authentication.mixins import AuthMixin
from authentication.notifications import OAuthBindMessage
from authentication.decorators import post_save_next_to_session_if_guard_redirect
from common.utils import get_logger
from common.utils.common import get_request_ip
from common.utils.django import reverse, get_object_or_none
@@ -72,6 +73,7 @@ class BaseLoginCallbackView(AuthMixin, FlashMessageMixin, IMClientMixin, View):
return user, None
@post_save_next_to_session_if_guard_redirect
def get(self, request: Request):
code = request.GET.get('code')
redirect_url = request.GET.get('redirect_url')
@@ -110,8 +112,6 @@ class BaseLoginCallbackView(AuthMixin, FlashMessageMixin, IMClientMixin, View):
response = self.get_failed_response(login_url, title=msg, msg=msg)
return response
if redirect_url and 'next=client' in redirect_url:
self.request.META['QUERY_STRING'] += '&next=client'
return self.redirect_to_guard_view()

View File

@@ -10,6 +10,7 @@ from django.views import View
from rest_framework.exceptions import APIException
from rest_framework.permissions import AllowAny, IsAuthenticated
from authentication.decorators import post_save_next_to_session_if_guard_redirect, pre_save_next_to_session
from authentication import errors
from authentication.const import ConfirmType
from authentication.mixins import AuthMixin
@@ -24,7 +25,7 @@ from common.views.mixins import PermissionsMixin, UserConfirmRequiredExceptionMi
from users.models import User
from users.views import UserVerifyPasswordView
from .base import BaseLoginCallbackView
from .mixins import METAMixin, FlashMessageMixin
from .mixins import FlashMessageMixin
logger = get_logger(__file__)
@@ -171,20 +172,18 @@ class DingTalkEnableStartView(UserVerifyPasswordView):
return success_url
class DingTalkQRLoginView(DingTalkQRMixin, METAMixin, View):
class DingTalkQRLoginView(DingTalkQRMixin, View):
permission_classes = (AllowAny,)
@pre_save_next_to_session()
def get(self, request: HttpRequest):
redirect_url = request.GET.get('redirect_url') or reverse('index')
query_string = request.GET.urlencode()
redirect_url = f'{redirect_url}?{query_string}'
next_url = self.get_next_url_from_meta() or reverse('index')
next_url = safe_next_url(next_url, request=request)
redirect_uri = reverse('authentication:dingtalk-qr-login-callback', external=True)
redirect_uri += '?' + urlencode({
'redirect_url': redirect_url,
'next': next_url,
})
url = self.get_qr_url(redirect_uri)
@@ -210,6 +209,7 @@ class DingTalkQRLoginCallbackView(DingTalkQRMixin, BaseLoginCallbackView):
class DingTalkOAuthLoginView(DingTalkOAuthMixin, View):
permission_classes = (AllowAny,)
@pre_save_next_to_session()
def get(self, request: HttpRequest):
redirect_url = request.GET.get('redirect_url')
@@ -223,6 +223,7 @@ class DingTalkOAuthLoginView(DingTalkOAuthMixin, View):
class DingTalkOAuthLoginCallbackView(AuthMixin, DingTalkOAuthMixin, View):
permission_classes = (AllowAny,)
@post_save_next_to_session_if_guard_redirect
def get(self, request: HttpRequest):
code = request.GET.get('code')
redirect_url = request.GET.get('redirect_url')

View File

@@ -8,6 +8,7 @@ from django.views import View
from rest_framework.exceptions import APIException
from rest_framework.permissions import AllowAny, IsAuthenticated
from authentication.decorators import pre_save_next_to_session
from authentication.const import ConfirmType
from authentication.permissions import UserConfirmation
from common.sdk.im.feishu import URL
@@ -108,9 +109,12 @@ class FeiShuQRBindCallbackView(FeiShuQRMixin, BaseBindCallbackView):
class FeiShuQRLoginView(FeiShuQRMixin, View):
permission_classes = (AllowAny,)
@pre_save_next_to_session()
def get(self, request: HttpRequest):
redirect_url = request.GET.get('redirect_url') or reverse('index')
query_string = request.GET.urlencode()
query_string = request.GET.copy()
query_string.pop('next', None)
query_string = query_string.urlencode()
redirect_url = f'{redirect_url}?{query_string}'
redirect_uri = reverse(f'authentication:{self.category}-qr-login-callback', external=True)
redirect_uri += '?' + urlencode({

View File

@@ -29,7 +29,7 @@ from users.utils import (
redirect_user_first_login_or_index
)
from .. import mixins, errors
from ..const import RSA_PRIVATE_KEY, RSA_PUBLIC_KEY
from ..const import RSA_PRIVATE_KEY, RSA_PUBLIC_KEY, USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD
from ..forms import get_user_login_form_cls
from ..utils import get_auth_methods
@@ -260,7 +260,7 @@ class UserLoginView(mixins.AuthMixin, UserLoginContextMixin, FormView):
class UserLoginGuardView(mixins.AuthMixin, RedirectView):
redirect_field_name = 'next'
redirect_field_name = USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD
login_url = reverse_lazy('authentication:login')
login_mfa_url = reverse_lazy('authentication:login-mfa')
login_confirm_url = reverse_lazy('authentication:login-wait-confirm')

View File

@@ -67,6 +67,7 @@ class UserLoginMFAView(mixins.AuthMixin, FormView):
def get_context_data(self, **kwargs):
user = self.get_user_from_session()
mfa_context = self.get_user_mfa_context(user)
print(mfa_context)
kwargs.update(mfa_context)
return kwargs

View File

@@ -4,17 +4,6 @@ from django.utils.translation import gettext_lazy as _
from common.utils import FlashMessageUtil
class METAMixin:
def get_next_url_from_meta(self):
request_meta = self.request.META or {}
next_url = None
referer = request_meta.get('HTTP_REFERER', '')
next_url_item = referer.rsplit('next=', 1)
if len(next_url_item) > 1:
next_url = next_url_item[-1]
return next_url
class FlashMessageMixin:
@staticmethod
def get_response(redirect_url='', title='', msg='', m_type='message', interval=5):

View File

@@ -8,6 +8,7 @@ from rest_framework.exceptions import APIException
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.request import Request
from authentication.decorators import pre_save_next_to_session
from authentication.const import ConfirmType
from authentication.permissions import UserConfirmation
from common.sdk.im.slack import URL, SLACK_REDIRECT_URI_SESSION_KEY
@@ -96,9 +97,12 @@ class SlackEnableStartView(UserVerifyPasswordView):
class SlackQRLoginView(SlackMixin, View):
permission_classes = (AllowAny,)
@pre_save_next_to_session()
def get(self, request: Request):
redirect_url = request.GET.get('redirect_url') or reverse('index')
query_string = request.GET.urlencode()
query_string = request.GET.copy()
query_string.pop('next', None)
query_string = query_string.urlencode()
redirect_url = f'{redirect_url}?{query_string}'
redirect_uri = reverse('authentication:slack-qr-login-callback', external=True)
redirect_uri += '?' + urlencode({

View File

@@ -12,6 +12,7 @@ from authentication import errors
from authentication.const import ConfirmType
from authentication.mixins import AuthMixin
from authentication.permissions import UserConfirmation
from authentication.decorators import post_save_next_to_session_if_guard_redirect, pre_save_next_to_session
from common.sdk.im.wecom import URL
from common.sdk.im.wecom import WeCom, wecom_tool
from common.utils import get_logger
@@ -20,7 +21,7 @@ from common.views.mixins import UserConfirmRequiredExceptionMixin, PermissionsMi
from users.models import User
from users.views import UserVerifyPasswordView
from .base import BaseLoginCallbackView, BaseBindCallbackView
from .mixins import METAMixin, FlashMessageMixin
from .mixins import FlashMessageMixin
logger = get_logger(__file__)
@@ -115,19 +116,14 @@ class WeComEnableStartView(UserVerifyPasswordView):
return success_url
class WeComQRLoginView(WeComQRMixin, METAMixin, View):
class WeComQRLoginView(WeComQRMixin, View):
permission_classes = (AllowAny,)
@pre_save_next_to_session()
def get(self, request: HttpRequest):
redirect_url = request.GET.get('redirect_url') or reverse('index')
next_url = self.get_next_url_from_meta() or reverse('index')
next_url = safe_next_url(next_url, request=request)
redirect_uri = reverse('authentication:wecom-qr-login-callback', external=True)
redirect_uri += '?' + urlencode({
'redirect_url': redirect_url,
'next': next_url,
})
redirect_uri += '?' + urlencode({'redirect_url': redirect_url})
url = self.get_qr_url(redirect_uri)
return HttpResponseRedirect(url)
@@ -148,12 +144,11 @@ class WeComQRLoginCallbackView(WeComQRMixin, BaseLoginCallbackView):
class WeComOAuthLoginView(WeComOAuthMixin, View):
permission_classes = (AllowAny,)
@pre_save_next_to_session()
def get(self, request: HttpRequest):
redirect_url = request.GET.get('redirect_url')
redirect_uri = reverse('authentication:wecom-oauth-login-callback', external=True)
redirect_uri += '?' + urlencode({'redirect_url': redirect_url})
url = self.get_oauth_url(redirect_uri)
return HttpResponseRedirect(url)
@@ -161,6 +156,7 @@ class WeComOAuthLoginView(WeComOAuthMixin, View):
class WeComOAuthLoginCallbackView(AuthMixin, WeComOAuthMixin, View):
permission_classes = (AllowAny,)
@post_save_next_to_session_if_guard_redirect
def get(self, request: HttpRequest):
code = request.GET.get('code')
redirect_url = request.GET.get('redirect_url')

View File

@@ -183,6 +183,7 @@ class BaseFileRenderer(LogMixin, BaseRenderer):
for item in data:
row = []
for field in render_fields:
field._row = item
value = item.get(field.field_name)
value = self.render_value(field, value)
row.append(value)

View File

@@ -2,6 +2,7 @@
#
import datetime
import inspect
import sys
if sys.version_info.major == 3 and sys.version_info.minor >= 10:
@@ -334,6 +335,10 @@ class ES(object):
def is_keyword(props: dict, field: str) -> bool:
return props.get(field, {}).get("type", "keyword") == "keyword"
@staticmethod
def is_long(props: dict, field: str) -> bool:
return props.get(field, {}).get("type") == "long"
def get_query_body(self, **kwargs):
new_kwargs = {}
for k, v in kwargs.items():
@@ -361,10 +366,10 @@ class ES(object):
if index_in_field in kwargs:
index['values'] = kwargs[index_in_field]
mapping = self.es.indices.get_mapping(index=self.query_index)
mapping = self.es.indices.get_mapping(index=self.index)
props = (
mapping
.get(self.query_index, {})
.get(self.index, {})
.get('mappings', {})
.get('properties', {})
)
@@ -375,6 +380,9 @@ class ES(object):
if k in ("org_id", "session") and self.is_keyword(props, k):
exact[k] = v
elif self.is_long(props, k):
exact[k] = v
elif k in common_keyword_able:
exact[f"{k}.keyword"] = v

View File

@@ -15,6 +15,7 @@ class Device:
self.__load_driver(driver_path)
# open device
self.__open_device()
self.__reset_key_store()
def close(self):
if self.__device is None:
@@ -68,3 +69,12 @@ class Device:
if ret != 0:
raise PiicoError("open piico device failed", ret)
self.__device = device
def __reset_key_store(self):
if self._driver is None:
raise PiicoError("no driver loaded", 0)
if self.__device is None:
raise PiicoError("device not open", 0)
ret = self._driver.SPII_ResetModule(self.__device)
if ret != 0:
raise PiicoError("reset device failed", ret)

View File

@@ -192,6 +192,7 @@ class WeCom(RequestMixin):
class WeComTool(object):
WECOM_STATE_SESSION_KEY = '_wecom_state'
WECOM_STATE_VALUE = 'wecom'
WECOM_STATE_NEXT_URL_KEY = 'wecom_oauth_next_url'
@lazyproperty
def qr_cb_url(self):

View File

@@ -101,7 +101,7 @@ def get_ip_city(ip):
info = get_ip_city_by_ipip(ip)
if info:
city = info.get('city', _("Unknown"))
city = info.get('city') or _("Unknown")
country = info.get('country')
# 国内城市 并且 语言是中文就使用国内

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -280,7 +280,8 @@
"CACertificate": "Ca certificate",
"CAS": "CAS",
"CMPP2": "Cmpp v2.0",
"CTYunPrivate": "eCloud Private Cloud",
"CTYun": "State Cloud",
"CTYunPrivate": "State Cloud(Private)",
"CalculationResults": "Error in cron expression",
"CallRecords": "Call Records",
"CanDragSelect": "Select by dragging; Empty means all selected",
@@ -1634,5 +1635,8 @@
"selectedAssets": "Selected assets",
"setVariable": "Set variable",
"userId": "User ID",
"userName": "User name"
}
"userName": "User name",
"AccessToken": "Access tokens",
"AccessTokenTip": "Access Token is a temporary credential generated through the OAuth2 (Authorization Code Grant) flow using the JumpServer client, which is used to access protected resources.",
"Revoke": "Revoke"
}

View File

@@ -279,6 +279,7 @@
"CACertificate": "CA 证书",
"CAS": "CAS",
"CMPP2": "CMPP v2.0",
"CTYun": "天翼云",
"CTYunPrivate": "天翼私有云",
"CalculationResults": "cron 表达式错误",
"CallRecords": "调用记录",
@@ -1644,5 +1645,8 @@
"userId": "用户ID",
"userName": "用户名",
"Risk": "风险",
"selectFiles": "已选择选择{number}文件"
}
"selectFiles": "已选择选择{number}文件",
"AccessToken": "访问令牌",
"AccessTokenTip": "访问令牌是通过 JumpServer 客户端使用 OAuth2授权码授权流程生成的临时凭证用于访问受保护的资源。",
"Revoke": "撤销"
}

View File

@@ -381,7 +381,6 @@ class Config(dict):
'CAS_USERNAME_ATTRIBUTE': 'cas:user',
'CAS_APPLY_ATTRIBUTES_TO_USER': False,
'CAS_RENAME_ATTRIBUTES': {'cas:user': 'username'},
'CAS_CREATE_USER': True,
'CAS_ORG_IDS': [DEFAULT_ID],
'AUTH_SSO': False,
@@ -692,9 +691,9 @@ class Config(dict):
'FTP_FILE_MAX_STORE': 0,
# API 分页
'MAX_LIMIT_PER_PAGE': 10000, # 给导出用
'MAX_LIMIT_PER_PAGE': 10000, # 给导出用
'MAX_PAGE_SIZE': 1000,
'DEFAULT_PAGE_SIZE': 200, # 给没有请求分页的用
'DEFAULT_PAGE_SIZE': 200, # 给没有请求分页的用
'LIMIT_SUPER_PRIV': False,
@@ -702,15 +701,7 @@ class Config(dict):
'CHAT_AI_ENABLED': False,
'CHAT_AI_METHOD': 'api',
'CHAT_AI_EMBED_URL': '',
'CHAT_AI_TYPE': 'gpt',
'GPT_BASE_URL': '',
'GPT_API_KEY': '',
'GPT_PROXY': '',
'GPT_MODEL': 'gpt-4o-mini',
'DEEPSEEK_BASE_URL': '',
'DEEPSEEK_API_KEY': '',
'DEEPSEEK_PROXY': '',
'DEEPSEEK_MODEL': 'deepseek-chat',
'CHAT_AI_PROVIDERS': [],
'VIRTUAL_APP_ENABLED': False,
'FILE_UPLOAD_SIZE_LIMIT_MB': 200,
@@ -735,6 +726,10 @@ class Config(dict):
# MCP
'MCP_ENABLED': False,
# oauth2_provider settings
'OAUTH2_PROVIDER_ACCESS_TOKEN_EXPIRE_SECONDS': 60 * 60,
'OAUTH2_PROVIDER_REFRESH_TOKEN_EXPIRE_SECONDS': 60 * 60 * 24 * 7,
}
old_config_map = {

View File

@@ -151,8 +151,13 @@ class SafeRedirectMiddleware:
if not (300 <= response.status_code < 400):
return response
if request.resolver_match and request.resolver_match.namespace.startswith('authentication'):
# 认证相关的路由跳过验证core/auth/xxxx
if (
request.resolver_match and
request.resolver_match.namespace.startswith('authentication') and
not request.resolver_match.namespace.startswith('authentication:oauth2-provider')
):
# 认证相关的路由跳过验证 /core/auth/...,
# 但 oauth2-provider 除外, 因为它会重定向到第三方客户端, 希望给出更友好的提示
return response
location = response.get('Location')
if not location:

View File

@@ -159,7 +159,8 @@ CAS_CHECK_NEXT = lambda _next_page: True
CAS_USERNAME_ATTRIBUTE = CONFIG.CAS_USERNAME_ATTRIBUTE
CAS_APPLY_ATTRIBUTES_TO_USER = CONFIG.CAS_APPLY_ATTRIBUTES_TO_USER
CAS_RENAME_ATTRIBUTES = CONFIG.CAS_RENAME_ATTRIBUTES
CAS_CREATE_USER = CONFIG.CAS_CREATE_USER
CAS_CREATE_USER = True
CAS_STORE_NEXT = True
# SSO auth
AUTH_SSO = CONFIG.AUTH_SSO

View File

@@ -130,6 +130,7 @@ INSTALLED_APPS = [
'settings.apps.SettingsConfig',
'terminal.apps.TerminalConfig',
'audits.apps.AuditsConfig',
'oauth2_provider',
'authentication.apps.AuthenticationConfig', # authentication
'tickets.apps.TicketsConfig',
'acls.apps.AclsConfig',

View File

@@ -241,15 +241,7 @@ ASSET_SIZE = 'small'
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
CHAT_AI_METHOD = CONFIG.CHAT_AI_METHOD
CHAT_AI_EMBED_URL = CONFIG.CHAT_AI_EMBED_URL
CHAT_AI_TYPE = CONFIG.CHAT_AI_TYPE
GPT_BASE_URL = CONFIG.GPT_BASE_URL
GPT_API_KEY = CONFIG.GPT_API_KEY
GPT_PROXY = CONFIG.GPT_PROXY
GPT_MODEL = CONFIG.GPT_MODEL
DEEPSEEK_BASE_URL = CONFIG.DEEPSEEK_BASE_URL
DEEPSEEK_API_KEY = CONFIG.DEEPSEEK_API_KEY
DEEPSEEK_PROXY = CONFIG.DEEPSEEK_PROXY
DEEPSEEK_MODEL = CONFIG.DEEPSEEK_MODEL
CHAT_AI_DEFAULT_PROVIDER = CONFIG.CHAT_AI_DEFAULT_PROVIDER
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
@@ -268,4 +260,6 @@ LOKI_BASE_URL = CONFIG.LOKI_BASE_URL
TOOL_USER_ENABLED = CONFIG.TOOL_USER_ENABLED
SUGGESTION_LIMIT = CONFIG.SUGGESTION_LIMIT
MCP_ENABLED = CONFIG.MCP_ENABLED
MCP_ENABLED = CONFIG.MCP_ENABLED
CHAT_AI_PROVIDERS = CONFIG.CHAT_AI_PROVIDERS

View File

@@ -30,10 +30,11 @@ REST_FRAMEWORK = {
),
'DEFAULT_AUTHENTICATION_CLASSES': (
# 'rest_framework.authentication.BasicAuthentication',
'authentication.backends.drf.AccessTokenAuthentication',
'authentication.backends.drf.PrivateTokenAuthentication',
'authentication.backends.drf.ServiceAuthentication',
'authentication.backends.drf.SignatureAuthentication',
'authentication.backends.drf.ServiceAuthentication',
'authentication.backends.drf.PrivateTokenAuthentication',
'authentication.backends.drf.AccessTokenAuthentication',
"oauth2_provider.contrib.rest_framework.OAuth2Authentication",
'authentication.backends.drf.SessionAuthentication',
),
'DEFAULT_FILTER_BACKENDS': (
@@ -222,3 +223,17 @@ PIICO_DRIVER_PATH = CONFIG.PIICO_DRIVER_PATH
LEAK_PASSWORD_DB_PATH = CONFIG.LEAK_PASSWORD_DB_PATH
JUMPSERVER_UPTIME = int(time.time())
# OAuth2 Provider settings
OAUTH2_PROVIDER = {
'ALLOWED_REDIRECT_URI_SCHEMES': ['https', 'jms'],
'PKCE_REQUIRED': True,
'ACCESS_TOKEN_EXPIRE_SECONDS': CONFIG.OAUTH2_PROVIDER_ACCESS_TOKEN_EXPIRE_SECONDS,
'REFRESH_TOKEN_EXPIRE_SECONDS': CONFIG.OAUTH2_PROVIDER_REFRESH_TOKEN_EXPIRE_SECONDS,
}
OAUTH2_PROVIDER_CLIENT_REDIRECT_URI = 'jms://auth/callback'
OAUTH2_PROVIDER_JUMPSERVER_CLIENT_NAME = 'JumpServer Client'
if CONFIG.DEBUG_DEV:
OAUTH2_PROVIDER['ALLOWED_REDIRECT_URI_SCHEMES'].append('http')
OAUTH2_PROVIDER_CLIENT_REDIRECT_URI += ' http://127.0.0.1:14876/auth/callback'

View File

@@ -148,6 +148,6 @@ class RedirectConfirm(TemplateView):
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
return False
if parsed.scheme not in ['http', 'https']:
if parsed.scheme not in ['http', 'https', 'jms']:
return False
return True

73
apps/perms/tree.py Normal file
View File

@@ -0,0 +1,73 @@
from collections import defaultdict
from django.db.models import Q, Count
from common.utils import get_logger
from users.models import User
from assets.tree.asset_tree import AssetTree, AssetTreeNode
from perms.utils.utils import UserPermUtil
__all__ = ['UserPermTree']
logger = get_logger(__name__)
class PermTreeNode(AssetTreeNode):
class Type:
# Neither a permission node nor a node with direct permission assets
BRIDGE = 'bridge'
# Node with direct permission
DN = 'dn'
# Node with only direct permission assets
DA = 'da'
def __init__(self, tp, _id, key, value, assets_count=0, assets=None):
super().__init__(_id, key, value, assets_count)
self.type = tp or self.Type.BRIDGE
def as_dict(self, simple=True):
data = super().as_dict(simple=simple)
data.update({
'type': self.type,
})
return data
class UserPermTree(AssetTree):
TreeNode = PermTreeNode
def __init__(self, user=None, assets_q_object=None, category=None, org=None, with_assets=False):
super().__init__(
assets_q_object=assets_q_object,
category=category,
org=org,
with_assets=with_assets,
full_tree=False
)
self._user: User = user
self._util = UserPermUtil(user, org=self._org)
def _make_assets_q_object(self):
q = super()._make_assets_q_object()
q_perm_assets = Q(id__in=self._util._user_direct_asset_ids)
q_perm_nodes = Q(node_id__in=self._util._user_direct_node_all_children_ids)
q = q & (q_perm_assets | q_perm_nodes)
return q
def _get_tree_node_data(self, node_id):
data = super()._get_tree_node_data(node_id)
if node_id in self._util._user_direct_node_all_children_ids:
tp = PermTreeNode.Type.DN
elif self._nodes_assets_count_mapper.get(node_id, 0) > 0:
tp = PermTreeNode.Type.DA
else:
tp = PermTreeNode.Type.BRIDGE
data.update({ 'tp': tp })
return data
def print(self, simple=True, count=10):
self._util.print()
super().print(simple=simple, count=count)

138
apps/perms/utils/utils.py Normal file
View File

@@ -0,0 +1,138 @@
from django.db.models import Q
from common.utils import timeit, lazyproperty, get_logger, is_uuid
from orgs.utils import current_org
from users.models import User
from assets.models import Node, Asset
from perms.models import AssetPermission
logger = get_logger(__name__)
__all__ = ['UserPermUtil']
class UserPermUtil(object):
UserGroupThrough = User.groups.through
PermUserThrough = AssetPermission.users.through
PermUserGroupThrough = AssetPermission.user_groups.through
PermAssetThrough = AssetPermission.assets.through
PermNodeThrough = AssetPermission.nodes.through
def __init__(self, user, org=None):
self._user: User = user
self._org = org or current_org
self._user_permission_ids = set()
self._user_group_ids = set()
self._user_group_permission_ids = set()
self._user_all_permission_ids = set()
self._user_direct_asset_ids = set()
self._user_direct_node_ids = set()
self._user_direct_node_all_children_ids = set()
self._init()
def _init(self):
self._load_user_permission_ids()
self._load_user_group_ids()
self._load_user_group_permission_ids()
self._load_user_direct_asset_ids()
self._load_user_direct_node_ids()
self._load_user_direct_node_all_children_ids()
@timeit
def _load_user_permission_ids(self):
perm_ids = self.PermUserThrough.objects.filter(
user_id=self._user.id
).distinct('assetpermission_id').values_list('assetpermission_id', flat=True)
perm_ids = self._uuids_to_string(perm_ids)
self._user_permission_ids.update(perm_ids)
self._user_all_permission_ids.update(perm_ids)
@timeit
def _load_user_group_ids(self):
group_ids = self.UserGroupThrough.objects.filter(
user_id=self._user.id
).distinct('usergroup_id').values_list('usergroup_id', flat=True)
group_ids = self._uuids_to_string(group_ids)
self._user_group_ids.update(group_ids)
@timeit
def _load_user_group_permission_ids(self):
perm_ids = self.PermUserGroupThrough.objects.filter(
usergroup_id__in=self._user_group_ids
).distinct('assetpermission_id').values_list('assetpermission_id', flat=True)
perm_ids = self._uuids_to_string(perm_ids)
self._user_group_permission_ids.update(perm_ids)
self._user_all_permission_ids.update(perm_ids)
@timeit
def _load_user_direct_asset_ids(self):
asset_ids = self.PermAssetThrough.objects.filter(
assetpermission_id__in=self._user_all_permission_ids
).distinct('asset_id').values_list('asset_id', flat=True)
asset_ids = self._uuids_to_string(asset_ids)
self._user_direct_asset_ids.update(asset_ids)
@timeit
def _load_user_direct_node_ids(self):
node_ids = self.PermNodeThrough.objects.filter(
assetpermission_id__in=self._user_all_permission_ids
).distinct('node_id').values_list('node_id', flat=True)
node_ids = self._uuids_to_string(node_ids)
self._user_direct_node_ids.update(node_ids)
@timeit
def _load_user_direct_node_all_children_ids(self):
nid_key_pairs = Node.objects.filter(org_id=self._org.id).values_list('id', 'key')
nid_key_mapper = { str(nid): key for nid, key in nid_key_pairs }
dn_keys = [ nid_key_mapper[nid] for nid in self._user_direct_node_ids ]
def has_ancestor_in_direct_nodes(key: str) -> bool:
ancestor_keys = [ ':'.join(key.split(':')[:i]) for i in range(1, key.count(':') + 1) ]
return bool(set(ancestor_keys) & set(dn_keys))
dn_children_ids = [ nid for nid, key in nid_key_mapper.items() if has_ancestor_in_direct_nodes(key) ]
self._user_direct_node_all_children_ids.update(self._user_direct_node_ids)
self._user_direct_node_all_children_ids.update(dn_children_ids)
def get_node_assets(self, node: Node):
''' 获取节点下授权的直接资产, Luna 页面展开时需要 '''
q = Q(node_id=node.id)
if str(node.id) not in self._user_direct_node_all_children_ids:
q &= Q(id__in=self._user_direct_asset_ids)
assets = Asset.objects.filter(q)
return assets
def get_node_all_assets(self, node: Node):
''' 获取节点及其子节点下所有授权资产, 测试时需要 '''
if str(node.id) in self._user_direct_node_all_children_ids:
assets = node.get_all_assets()
return assets
children_ids = node.get_all_children(with_self=True).values_list('id', flat=True)
children_ids = self._uuids_to_string(children_ids)
dn_all_nodes_ids = set(children_ids) & self._user_direct_node_all_children_ids
other_nodes_ids = set(children_ids) - dn_all_nodes_ids
q = Q(node_id__in=dn_all_nodes_ids)
q |= Q(node_id__in=other_nodes_ids) & Q(id__in=self._user_direct_asset_ids)
assets = Asset.objects.filter(q)
return assets
def _uuids_to_string(self, uuids):
return [ str(u) for u in uuids ]
def print(self):
print('user_perm_tree:', self._user.username)
print('user_permission_ids_count:', len(self._user_permission_ids))
print('user_group_ids_count:', len(self._user_group_ids))
print('user_group_permission_ids_count:', len(self._user_permission_ids) - len(self._user_group_ids))
print('user_all_permission_ids_count:', len(self._user_all_permission_ids))
print('user_direct_asset_ids_count:', len(self._user_direct_asset_ids))
print('user_direct_node_ids_count:', len(self._user_direct_node_ids))
print('user_direct_node_all_children_ids_count:', len(self._user_direct_node_all_children_ids))

View File

@@ -133,6 +133,11 @@ exclude_permissions = (
('terminal', 'session', 'delete,change', 'command'),
('applications', '*', '*', '*'),
('settings', 'chatprompt', 'add,delete,change', 'chatprompt'),
('oauth2_provider', 'grant', '*', '*'),
('oauth2_provider', 'refreshtoken', '*', '*'),
('oauth2_provider', 'idtoken', '*', '*'),
('oauth2_provider', 'application', '*', '*'),
('oauth2_provider', 'accesstoken', 'add,change', 'accesstoken')
)
only_system_permissions = (
@@ -160,6 +165,7 @@ only_system_permissions = (
('authentication', 'temptoken', '*', '*'),
('authentication', 'passkey', '*', '*'),
('authentication', 'ssotoken', '*', '*'),
('oauth2_provider', 'accesstoken', '*', '*'),
('tickets', '*', '*', '*'),
('orgs', 'organization', 'view', 'rootorg'),
('terminal', 'applet', '*', '*'),

View File

@@ -129,6 +129,7 @@ special_pid_mapper = {
"rbac.view_systemtools": "view_workbench",
'tickets.view_ticket': 'tickets',
"audits.joblog": "job_audit",
'oauth2_provider.accesstoken': 'authentication',
}
special_setting_pid_mapper = {
@@ -184,6 +185,11 @@ verbose_name_mapper = {
'tickets.view_ticket': _("Ticket"),
'settings.setting': _("Common setting"),
'rbac.view_permission': _('View permission tree'),
'authentication.passkey': _("Passkey"),
'oauth2_provider.accesstoken': _("Access token"),
'oauth2_provider.view_accesstoken': _("View access token"),
'oauth2_provider.delete_accesstoken': _("Revoke access token"),
}
xpack_nodes = [

View File

@@ -1,98 +1,10 @@
import httpx
import openai
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from rest_framework import status
from rest_framework.generics import GenericAPIView
from rest_framework.response import Response
from common.api import JMSModelViewSet
from common.permissions import IsValidUser, OnlySuperUser
from .. import serializers
from ..const import ChatAITypeChoices
from ..models import ChatPrompt
from ..prompt import DefaultChatPrompt
class ChatAITestingAPI(GenericAPIView):
serializer_class = serializers.ChatAISettingSerializer
rbac_perms = {
'POST': 'settings.change_chatai'
}
def get_config(self, request):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
data = self.serializer_class().data
data.update(serializer.validated_data)
for k, v in data.items():
if v:
continue
# 页面没有传递值, 从 settings 中获取
data[k] = getattr(settings, k, None)
return data
def post(self, request):
config = self.get_config(request)
chat_ai_enabled = config['CHAT_AI_ENABLED']
if not chat_ai_enabled:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data={'msg': _('Chat AI is not enabled')}
)
tp = config['CHAT_AI_TYPE']
if tp == ChatAITypeChoices.gpt:
url = config['GPT_BASE_URL']
api_key = config['GPT_API_KEY']
proxy = config['GPT_PROXY']
model = config['GPT_MODEL']
else:
url = config['DEEPSEEK_BASE_URL']
api_key = config['DEEPSEEK_API_KEY']
proxy = config['DEEPSEEK_PROXY']
model = config['DEEPSEEK_MODEL']
kwargs = {
'base_url': url or None,
'api_key': api_key,
}
try:
if proxy:
kwargs['http_client'] = httpx.Client(
proxies=proxy,
transport=httpx.HTTPTransport(local_address='0.0.0.0')
)
client = openai.OpenAI(**kwargs)
ok = False
error = ''
client.chat.completions.create(
messages=[
{
"role": "user",
"content": "Say this is a test",
}
],
model=model,
)
ok = True
except openai.APIConnectionError as e:
error = str(e.__cause__) # an underlying Exception, likely raised within httpx.
except openai.APIStatusError as e:
error = str(e.message)
except Exception as e:
ok, error = False, str(e)
if ok:
_status, msg = status.HTTP_200_OK, _('Test success')
else:
_status, msg = status.HTTP_400_BAD_REQUEST, error
return Response(status=_status, data={'msg': msg})
class ChatPromptViewSet(JMSModelViewSet):
serializer_classes = {
'default': serializers.ChatPromptSerializer,

View File

@@ -43,6 +43,7 @@ class ComponentI18nApi(RetrieveAPIView):
if not lang:
return Response(data)
lang = lang.lower()
if lang not in dict(Language.get_code_mapper()).keys():
lang = 'en'

View File

@@ -154,7 +154,10 @@ class SettingsApi(generics.RetrieveUpdateAPIView):
def parse_serializer_data(self, serializer):
data = []
fields = self.get_fields()
encrypted_items = [name for name, field in fields.items() if field.write_only]
encrypted_items = [
name for name, field in fields.items()
if field.write_only or getattr(field, 'encrypted', False)
]
category = self.request.query_params.get('category', '')
for name, value in serializer.validated_data.items():
encrypted = name in encrypted_items

View File

@@ -14,18 +14,5 @@ class ChatAIMethodChoices(TextChoices):
class ChatAITypeChoices(TextChoices):
gpt = 'gpt', 'GPT'
deep_seek = 'deep-seek', 'DeepSeek'
class GPTModelChoices(TextChoices):
gpt_4o_mini = 'gpt-4o-mini', 'gpt-4o-mini'
gpt_4o = 'gpt-4o', 'gpt-4o'
o3_mini = 'o3-mini', 'o3-mini'
o1_mini = 'o1-mini', 'o1-mini'
o1 = 'o1', 'o1'
class DeepSeekModelChoices(TextChoices):
deepseek_chat = 'deepseek-chat', 'DeepSeek-V3'
deepseek_reasoner = 'deepseek-reasoner', 'DeepSeek-R1'
openai = 'openai', 'Openai'
ollama = 'ollama', 'Ollama'

View File

@@ -0,0 +1,23 @@
# Generated by Django 4.1.13 on 2025-11-27 02:54
from django.db import migrations, connections
def refresh_pg_collation(apps, schema_editor):
for alias, conn in connections.databases.items():
if connections[alias].vendor == "postgresql":
dbname = connections[alias].settings_dict["NAME"]
connections[alias].cursor().execute(
f'ALTER DATABASE "{dbname}" REFRESH COLLATION VERSION;'
)
print(f"Refreshed postgresql collation version for database: {dbname} successfully.")
class Migration(migrations.Migration):
dependencies = [
('settings', '0002_leakpasswords'),
]
operations = [
migrations.RunPython(refresh_pg_collation, migrations.RunPython.noop),
]

View File

@@ -1,6 +1,7 @@
import json
import os
import shutil
from typing import Any, Dict, List
from django.conf import settings
from django.core.files.base import ContentFile
@@ -14,7 +15,6 @@ from rest_framework.utils.encoders import JSONEncoder
from common.db.models import JMSBaseModel
from common.db.utils import Encryptor
from common.utils import get_logger
from .const import ChatAITypeChoices
from .signals import setting_changed
logger = get_logger(__name__)
@@ -196,20 +196,25 @@ class ChatPrompt(JMSBaseModel):
return self.name
def get_chatai_data():
data = {
'url': settings.GPT_BASE_URL,
'api_key': settings.GPT_API_KEY,
'proxy': settings.GPT_PROXY,
'model': settings.GPT_MODEL,
}
if settings.CHAT_AI_TYPE != ChatAITypeChoices.gpt:
data['url'] = settings.DEEPSEEK_BASE_URL
data['api_key'] = settings.DEEPSEEK_API_KEY
data['proxy'] = settings.DEEPSEEK_PROXY
data['model'] = settings.DEEPSEEK_MODEL
def get_chatai_data() -> Dict[str, Any]:
raw_providers = settings.CHAT_AI_PROVIDERS
providers: List[dict] = [p for p in raw_providers if isinstance(p, dict)]
return data
if not providers:
return {}
selected = next(
(p for p in providers if p.get('is_assistant')),
providers[0],
)
return {
'url': selected.get('base_url'),
'api_key': selected.get('api_key'),
'proxy': selected.get('proxy'),
'model': selected.get('model'),
'name': selected.get('name'),
}
def init_sqlite_db():

View File

@@ -37,11 +37,4 @@ class CASSettingSerializer(serializers.Serializer):
"and the `value` is the JumpServer user attribute name"
)
)
CAS_CREATE_USER = serializers.BooleanField(
required=False, label=_('Create user'),
help_text=_(
'After successful user authentication, if the user does not exist, '
'automatically create the user'
)
)
CAS_ORG_IDS = OrgListField()
CAS_ORG_IDS = OrgListField()

View File

@@ -10,11 +10,12 @@ from common.utils import date_expired_default
__all__ = [
'AnnouncementSettingSerializer', 'OpsSettingSerializer', 'VaultSettingSerializer',
'HashicorpKVSerializer', 'AzureKVSerializer', 'TicketSettingSerializer',
'ChatAISettingSerializer', 'VirtualAppSerializer', 'AmazonSMSerializer',
'ChatAIProviderSerializer', 'ChatAISettingSerializer',
'VirtualAppSerializer', 'AmazonSMSerializer',
]
from settings.const import (
ChatAITypeChoices, GPTModelChoices, DeepSeekModelChoices, ChatAIMethodChoices
ChatAITypeChoices, ChatAIMethodChoices
)
@@ -120,6 +121,29 @@ class AmazonSMSerializer(serializers.Serializer):
)
class ChatAIProviderListSerializer(serializers.ListSerializer):
# 标记整个列表需要加密存储,避免明文保存 API Key
encrypted = True
class ChatAIProviderSerializer(serializers.Serializer):
type = serializers.ChoiceField(
default=ChatAITypeChoices.openai, choices=ChatAITypeChoices.choices,
label=_("Types"), required=False,
)
base_url = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
api_key = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
proxy = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
class ChatAISettingSerializer(serializers.Serializer):
PREFIX_TITLE = _('Chat AI')
@@ -130,44 +154,14 @@ class ChatAISettingSerializer(serializers.Serializer):
default=ChatAIMethodChoices.api, choices=ChatAIMethodChoices.choices,
label=_("Method"), required=False,
)
CHAT_AI_PROVIDERS = ChatAIProviderListSerializer(
child=ChatAIProviderSerializer(),
allow_empty=True, required=False, default=list, label=_('Providers')
)
CHAT_AI_EMBED_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
CHAT_AI_TYPE = serializers.ChoiceField(
default=ChatAITypeChoices.gpt, choices=ChatAITypeChoices.choices,
label=_("Types"), required=False,
)
GPT_BASE_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
GPT_API_KEY = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
GPT_PROXY = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
GPT_MODEL = serializers.ChoiceField(
default=GPTModelChoices.gpt_4o_mini, choices=GPTModelChoices.choices,
label=_("GPT Model"), required=False,
)
DEEPSEEK_BASE_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
DEEPSEEK_API_KEY = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
DEEPSEEK_PROXY = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
DEEPSEEK_MODEL = serializers.ChoiceField(
default=DeepSeekModelChoices.deepseek_chat, choices=DeepSeekModelChoices.choices,
label=_("DeepSeek Model"), required=False,
)
class TicketSettingSerializer(serializers.Serializer):

View File

@@ -73,8 +73,6 @@ class PrivateSettingSerializer(PublicSettingSerializer):
CHAT_AI_ENABLED = serializers.BooleanField()
CHAT_AI_METHOD = serializers.CharField()
CHAT_AI_EMBED_URL = serializers.CharField()
CHAT_AI_TYPE = serializers.CharField()
GPT_MODEL = serializers.CharField()
FILE_UPLOAD_SIZE_LIMIT_MB = serializers.IntegerField()
FTP_FILE_MAX_STORE = serializers.IntegerField()
LOKI_LOG_ENABLED = serializers.BooleanField()

View File

@@ -21,7 +21,6 @@ urlpatterns = [
path('sms/<str:backend>/testing/', api.SMSTestingAPI.as_view(), name='sms-testing'),
path('sms/backend/', api.SMSBackendAPI.as_view(), name='sms-backend'),
path('vault/<str:backend>/testing/', api.VaultTestingAPI.as_view(), name='vault-testing'),
path('chatai/testing/', api.ChatAITestingAPI.as_view(), name='chatai-testing'),
path('vault/sync/', api.VaultSyncDataAPI.as_view(), name='vault-sync'),
path('security/block-ip/', api.BlockIPSecurityAPI.as_view(), name='block-ip'),
path('security/unlock-ip/', api.UnlockIPSecurityAPI.as_view(), name='unlock-ip'),

View File

@@ -524,13 +524,16 @@ class LDAPTestUtil(object):
# test server uri
def _check_server_uri(self):
if not any([self.config.server_uri.startswith('ldap://') or
self.config.server_uri.startswith('ldaps://')]):
if not (self.config.server_uri.startswith('ldap://') or
self.config.server_uri.startswith('ldaps://')):
err = _('ldap:// or ldaps:// protocol is used.')
raise LDAPInvalidServerError(err)
def _test_server_uri(self):
self._test_connection_bind()
# 这里测试 server uri 是否能连通, 不进行 bind 操作, 不需要传入 bind dn 和密码
server = Server(self.config.server_uri, use_ssl=self.config.use_ssl)
connection = Connection(server)
connection.open()
def test_server_uri(self):
try:

View File

@@ -6,10 +6,6 @@
{% block content %}
<style>
.alert.alert-msg {
background: #F5F5F7;
}
.target-url {
display: inline-block;
max-width: 100%;
@@ -18,29 +14,96 @@
overflow: hidden;
text-overflow: ellipsis;
vertical-align: middle;
cursor: pointer;
}
/* 重定向中的样式 */
.redirecting-container {
display: none;
}
.redirecting-container.show {
display: block;
}
.confirm-container {
transition: opacity 0.3s ease-out;
}
.confirm-container.hide {
display: none;
}
</style>
<div>
<p>
<div class="alert {% if error %} alert-danger {% else %} alert-info {% endif %}" id="messages">
{% trans 'You are about to be redirected to an external website. Please confirm that you trust this link: ' %}
<a class="target-url" href="{{ target_url }}">{{ target_url }}</a>
</div>
</p>
<!-- 确认内容 -->
<div class="confirm-container" id="confirmContainer">
<p>
<div class="alert {% if error %} alert-danger {% else %} alert-info {% endif %}" id="messages">
{% trans 'You are about to be redirected to an external website.' %}
<br/>
<br/>
{% trans 'Please confirm that you trust this link: ' %}
<br/>
<br/>
<a class="target-url" href="javascript:void(0)" onclick="handleRedirect(event)">{{ target_url }}</a>
</div>
</p>
<div class="row">
<div class="col-sm-3">
<a href="/" class="btn btn-default block full-width m-b">
{% trans 'Cancel' %}
</a>
</div>
<div class="col-sm-3">
<a href="{{ target_url }}" class="btn btn-primary block full-width m-b">
{% trans 'Confirm' %}
</a>
<div class="row">
<div class="col-sm-3">
<a href="/" class="btn btn-default block full-width m-b">
{% trans 'Back' %}
</a>
</div>
<div class="col-sm-3">
<a href="javascript:void(0)" onclick="handleRedirect(event)" class="btn btn-primary block full-width m-b">
{% trans 'Confirm' %}
</a>
</div>
</div>
</div>
<!-- 重定向中内容 -->
<div class="redirecting-container" id="redirectingContainer">
<p>
<div class="alert alert-info" id="messages">
{% trans 'Redirecting you to the Desktop App ( JumpServer Client )' %}
<br/>
<br/>
{% trans 'You can safely close this window and return to the application.' %}
</div>
</p>
</div>
</div>
<script>
const targetUrl = '{{ target_url }}';
// 判断是否是 jms:// 协议
function isJmsProtocol(url) {
return url.toLowerCase().startsWith('jms://');
}
function handleRedirect(event) {
// 如果有 event阻止默认行为
if (event) {
event.preventDefault();
}
if (isJmsProtocol(targetUrl)) {
// 隐藏确认内容
document.getElementById('confirmContainer').classList.add('hide');
// 显示重定向中
document.getElementById('redirectingContainer').classList.add('show');
}
// 延迟后执行跳转让用户看到加载动画
setTimeout(() => {
window.location.href = targetUrl;
}, 100);
}
</script>
{% endblock %}

View File

@@ -22,13 +22,13 @@ p {
{% trans 'JumpServerClient, currently used to launch the client' %}
</p>
<ul>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_x64_en-US.msi">JumpServerClient-x64_en-US.msi</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_x64-setup.exe">JumpServerClient-x64-setup.exe</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_aarch64.dmg">JumpServerClient-aarch64.dmg</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_x64.dmg">JumpServerClient-x64.dmg</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_amd64.AppImage">JumpServerClient-amd64.AppImage</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_amd64.deb">JumpServerClient-amd64.deb</a></li>
<li> <a href="/download/public/JumpServerClient-{{ CLIENT_VERSION }}-1.x86_64.rpm">JumpServerClient-1.x86_64.rpm</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_x64_en-US.msi">JumpServerClient_{{ CLIENT_VERSION }}_x64_en-US.msi</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_x64-setup.exe">JumpServerClient_{{ CLIENT_VERSION }}_x64-setup.exe</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_aarch64.dmg">JumpServerClient_{{ CLIENT_VERSION }}_aarch64.dmg</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_x64.dmg">JumpServerClient_{{ CLIENT_VERSION }}_x64.dmg</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_amd64.AppImage">JumpServerClient_{{ CLIENT_VERSION }}_amd64.AppImage</a></li>
<li> <a href="/download/public/JumpServerClient_{{ CLIENT_VERSION }}_amd64.deb">JumpServerClient_{{ CLIENT_VERSION }}_amd64.deb</a></li>
<li> <a href="/download/public/JumpServerClient-{{ CLIENT_VERSION }}-1.x86_64.rpm">JumpServerClient-{{ CLIENT_VERSION }}-1.x86_64.rpm</a></li>
</ul>
</div>

View File

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
#
from .applet import *
from .chat import *
from .component import *
from .session import *
from .virtualapp import *

View File

@@ -0,0 +1 @@
from .chat import *

View File

@@ -0,0 +1,15 @@
from common.api import JMSBulkModelViewSet
from terminal import serializers
from terminal.filters import ChatFilter
from terminal.models import Chat
__all__ = ['ChatViewSet']
class ChatViewSet(JMSBulkModelViewSet):
queryset = Chat.objects.all()
serializer_class = serializers.ChatSerializer
filterset_class = ChatFilter
search_fields = ['title']
ordering_fields = ['date_updated']
ordering = ['-date_updated']

View File

@@ -1,12 +1,11 @@
# -*- coding: utf-8 -*-
#
import pytz
from datetime import datetime
from common.utils import get_logger
from common.plugins.es import ES
import pytz
from common.plugins.es import ES
from common.utils import get_logger
logger = get_logger(__file__)
@@ -27,8 +26,8 @@ class CommandStore(ES):
"type": "long"
}
}
exact_fields = {}
fuzzy_fields = {'input', 'risk_level', 'user', 'asset', 'account'}
exact_fields = {'risk_level'}
fuzzy_fields = {'input', 'user', 'asset', 'account'}
match_fields = {'input'}
keyword_fields = {'session', 'org_id'}

View File

@@ -2,7 +2,7 @@ from django.db.models import QuerySet
from django_filters import rest_framework as filters
from orgs.utils import filter_org_queryset
from terminal.models import Command, CommandStorage, Session
from terminal.models import Command, CommandStorage, Session, Chat
class CommandFilter(filters.FilterSet):
@@ -79,7 +79,34 @@ class CommandStorageFilter(filters.FilterSet):
model = CommandStorage
fields = ['real', 'name', 'type', 'is_default']
def filter_real(self, queryset, name, value):
@staticmethod
def filter_real(queryset, name, value):
if value:
queryset = queryset.exclude(name='null')
return queryset
class ChatFilter(filters.FilterSet):
ids = filters.BooleanFilter(method='filter_ids')
folder_ids = filters.BooleanFilter(method='filter_folder_ids')
class Meta:
model = Chat
fields = [
'title', 'user_id', 'pinned', 'folder_id',
'archived', 'socket_id', 'share_id'
]
@staticmethod
def filter_ids(queryset, name, value):
ids = value.split(',')
queryset = queryset.filter(id__in=ids)
return queryset
@staticmethod
def filter_folder_ids(queryset, name, value):
ids = value.split(',')
queryset = queryset.filter(folder_id__in=ids)
return queryset

View File

@@ -0,0 +1,38 @@
# Generated by Django 4.1.13 on 2025-09-30 06:57
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
('terminal', '0010_alter_command_risk_level_alter_session_login_from_and_more'),
]
operations = [
migrations.CreateModel(
name='Chat',
fields=[
('created_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('title', models.CharField(max_length=256, verbose_name='Title')),
('chat', models.JSONField(default=dict, verbose_name='Chat')),
('meta', models.JSONField(default=dict, verbose_name='Meta')),
('pinned', models.BooleanField(default=False, verbose_name='Pinned')),
('archived', models.BooleanField(default=False, verbose_name='Archived')),
('share_id', models.CharField(blank=True, default='', max_length=36)),
('folder_id', models.CharField(blank=True, default='', max_length=36)),
('socket_id', models.CharField(blank=True, default='', max_length=36)),
('user_id', models.CharField(blank=True, db_index=True, default='', max_length=36)),
('session_info', models.JSONField(default=dict, verbose_name='Session Info')),
],
options={
'verbose_name': 'Chat',
},
),
]

View File

@@ -1,4 +1,5 @@
from .session import *
from .component import *
from .applet import *
from .chat import *
from .component import *
from .session import *
from .virtualapp import *

View File

@@ -0,0 +1 @@
from .chat import *

View File

@@ -0,0 +1,30 @@
from django.db import models
from django.utils.translation import gettext_lazy as _
from common.db.models import JMSBaseModel
from common.utils import get_logger
logger = get_logger(__name__)
__all__ = ['Chat']
class Chat(JMSBaseModel):
# id == session_id # 36 chars
title = models.CharField(max_length=256, verbose_name=_('Title'))
chat = models.JSONField(default=dict, verbose_name=_('Chat'))
meta = models.JSONField(default=dict, verbose_name=_('Meta'))
pinned = models.BooleanField(default=False, verbose_name=_('Pinned'))
archived = models.BooleanField(default=False, verbose_name=_('Archived'))
share_id = models.CharField(blank=True, default='', max_length=36)
folder_id = models.CharField(blank=True, default='', max_length=36)
socket_id = models.CharField(blank=True, default='', max_length=36)
user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
session_info = models.JSONField(default=dict, verbose_name=_('Session Info'))
class Meta:
verbose_name = _('Chat')
def __str__(self):
return self.title

Some files were not shown because too many files have changed in this diff Show More