mirror of
https://github.com/jumpserver/jumpserver.git
synced 2025-12-18 18:12:37 +00:00
Compare commits
49 Commits
pr@dev@per
...
v5_refacto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56f720271a | ||
|
|
9755076f7f | ||
|
|
8d7abef191 | ||
|
|
aaa40722c4 | ||
|
|
ca39344937 | ||
|
|
4b9a8227c9 | ||
|
|
f362163af1 | ||
|
|
5f1ba56e56 | ||
|
|
2b1fdb937b | ||
|
|
1e754546f1 | ||
|
|
2ec71feafc | ||
|
|
02e8905330 | ||
|
|
8d68f5589b | ||
|
|
4df13fc384 | ||
|
|
78c1162028 | ||
|
|
14c2512b45 | ||
|
|
d6d7072da5 | ||
|
|
993bc36c5e | ||
|
|
ecff2ea07e | ||
|
|
ba70edf221 | ||
|
|
50050dff57 | ||
|
|
944226866c | ||
|
|
fe13221d88 | ||
|
|
ba17863892 | ||
|
|
065bfeda52 | ||
|
|
04af26500a | ||
|
|
e0388364c3 | ||
|
|
3c96480b0c | ||
|
|
95331a0c4b | ||
|
|
b8ecb703cf | ||
|
|
1a3f5e3f9a | ||
|
|
854396e8d5 | ||
|
|
ab08603e66 | ||
|
|
427fd3f72c | ||
|
|
0aba9ba120 | ||
|
|
045ca8807a | ||
|
|
19a68d8930 | ||
|
|
75ed02a2d2 | ||
|
|
f420dac49c | ||
|
|
1ee68134f2 | ||
|
|
937265db5d | ||
|
|
c611d5e88b | ||
|
|
883b6b6383 | ||
|
|
ac4c72064f | ||
|
|
dbf8360e27 | ||
|
|
150d7a09bc | ||
|
|
a7ed20e059 | ||
|
|
1b7b8e6f2e | ||
|
|
cd22fbce19 |
@@ -1,4 +1,4 @@
|
||||
FROM jumpserver/core-base:20251113_092612 AS stage-build
|
||||
FROM jumpserver/core-base:20251128_025056 AS stage-build
|
||||
|
||||
ARG VERSION
|
||||
|
||||
|
||||
126
apps/assets/migrations/0020_asset_node.py
Normal file
126
apps/assets/migrations/0020_asset_node.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# Generated by Django 4.1.13 on 2025-12-16 09:14
|
||||
|
||||
from django.db import migrations, models, transaction
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
def log(msg=''):
|
||||
print(f' -> {msg}')
|
||||
|
||||
|
||||
def ensure_asset_single_node(apps, schema_editor):
|
||||
print('')
|
||||
log('Checking that all assets are linked to only one node...')
|
||||
Asset = apps.get_model('assets', 'Asset')
|
||||
Through = Asset.nodes.through
|
||||
|
||||
assets_count_multi_nodes = Through.objects.values('asset_id').annotate(
|
||||
node_count=models.Count('node_id')
|
||||
).filter(node_count__gt=1).count()
|
||||
|
||||
if assets_count_multi_nodes > 0:
|
||||
raise Exception(
|
||||
f'There are {assets_count_multi_nodes} assets associated with more than one node. '
|
||||
'Please ensure each asset is linked to only one node before applying this migration.'
|
||||
)
|
||||
else:
|
||||
log('All assets are linked to only one node. Proceeding with the migration.')
|
||||
|
||||
|
||||
def ensure_asset_has_node(apps, schema_editor):
|
||||
log('Checking that all assets are linked to at least one node...')
|
||||
Asset = apps.get_model('assets', 'Asset')
|
||||
Through = Asset.nodes.through
|
||||
|
||||
asset_count = Asset.objects.count()
|
||||
through_asset_count = Through.objects.values('asset_id').count()
|
||||
|
||||
assets_count_without_node = asset_count - through_asset_count
|
||||
|
||||
if assets_count_without_node > 0:
|
||||
raise Exception(
|
||||
f'Some assets ({assets_count_without_node}) are not associated with any node. '
|
||||
'Please ensure all assets are linked to a node before applying this migration.'
|
||||
)
|
||||
else:
|
||||
log('All assets are linked to a node. Proceeding with the migration.')
|
||||
|
||||
|
||||
def migrate_asset_node_id_field(apps, schema_editor):
|
||||
log('Migrating node_id field for all assets...')
|
||||
|
||||
Asset = apps.get_model('assets', 'Asset')
|
||||
Through = Asset.nodes.through
|
||||
|
||||
assets = Asset.objects.filter(node_id__isnull=True)
|
||||
log (f'Found {assets.count()} assets to migrate.')
|
||||
|
||||
asset_node_mapper = {
|
||||
str(asset_id): str(node_id)
|
||||
for asset_id, node_id in Through.objects.values_list('asset_id', 'node_id')
|
||||
}
|
||||
# 测试
|
||||
asset_node_mapper.pop(None, None) # Remove any entries with None keys
|
||||
|
||||
for asset in assets:
|
||||
node_id = asset_node_mapper.get(str(asset.id))
|
||||
if not node_id:
|
||||
raise Exception(
|
||||
f'Asset (ID: {asset.id}) is not associated with any node. '
|
||||
'Cannot migrate node_id field.'
|
||||
)
|
||||
asset.node_id = node_id
|
||||
|
||||
with transaction.atomic():
|
||||
total = len(assets)
|
||||
batch_size = 5000
|
||||
|
||||
for i in range(0, total, batch_size):
|
||||
batch = assets[i:i+batch_size]
|
||||
start = i + 1
|
||||
end = min(i + batch_size, total)
|
||||
|
||||
for asset in batch:
|
||||
asset.save(update_fields=['node_id'])
|
||||
|
||||
log(f"Migrated {start}-{end}/{total} assets")
|
||||
|
||||
count = Asset.objects.filter(node_id__isnull=True).count()
|
||||
if count > 0:
|
||||
log('Warning: Some assets still have null node_id after migration.')
|
||||
raise Exception('Migration failed: Some assets have null node_id.')
|
||||
|
||||
count = Asset.objects.filter(node_id__isnull=False).count()
|
||||
log(f'Successfully migrated node_id for {count} assets.')
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('assets', '0019_alter_asset_connectivity'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(
|
||||
ensure_asset_single_node,
|
||||
reverse_code=migrations.RunPython.noop
|
||||
),
|
||||
migrations.RunPython(
|
||||
ensure_asset_has_node,
|
||||
reverse_code=migrations.RunPython.noop
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='asset',
|
||||
name='node',
|
||||
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='direct_assets', to='assets.node', verbose_name='Node'),
|
||||
),
|
||||
migrations.RunPython(
|
||||
migrate_asset_node_id_field,
|
||||
reverse_code=migrations.RunPython.noop
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='asset',
|
||||
name='node',
|
||||
field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='direct_assets', to='assets.node', verbose_name='Node'),
|
||||
),
|
||||
]
|
||||
@@ -172,6 +172,11 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
|
||||
"assets.Zone", null=True, blank=True, related_name='assets',
|
||||
verbose_name=_("Zone"), on_delete=models.SET_NULL
|
||||
)
|
||||
node = models.ForeignKey(
|
||||
'assets.Node', null=False, blank=False, on_delete=models.PROTECT,
|
||||
related_name='direct_assets', verbose_name=_("Node")
|
||||
)
|
||||
# TODO: 删除完代码中所有使用的地方后,再删除 nodes 字段,并将 node 字段的 related_name 改为 'assets'
|
||||
nodes = models.ManyToManyField(
|
||||
'assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")
|
||||
)
|
||||
|
||||
@@ -186,6 +186,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
super().__init__(*args, **kwargs)
|
||||
self._init_field_choices()
|
||||
self._extract_accounts()
|
||||
self._set_platform()
|
||||
|
||||
def _extract_accounts(self):
|
||||
if not getattr(self, 'initial_data', None):
|
||||
@@ -217,6 +218,21 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
protocols_data = [{'name': p.name, 'port': p.port} for p in protocols]
|
||||
self.initial_data['protocols'] = protocols_data
|
||||
|
||||
def _set_platform(self):
|
||||
if not hasattr(self, 'initial_data'):
|
||||
return
|
||||
platform_id = self.initial_data.get('platform')
|
||||
if not platform_id:
|
||||
return
|
||||
|
||||
if isinstance(platform_id, int) or str(platform_id).isdigit() or not isinstance(platform_id, str):
|
||||
return
|
||||
|
||||
platform = Platform.objects.filter(name=platform_id).first()
|
||||
if not platform:
|
||||
return
|
||||
self.initial_data['platform'] = platform.id
|
||||
|
||||
def _init_field_choices(self):
|
||||
request = self.context.get('request')
|
||||
if not request:
|
||||
@@ -265,8 +281,10 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
|
||||
if not platform_id and self.instance:
|
||||
platform = self.instance.platform
|
||||
else:
|
||||
elif isinstance(platform_id, int):
|
||||
platform = Platform.objects.filter(id=platform_id).first()
|
||||
else:
|
||||
platform = Platform.objects.filter(name=platform_id).first()
|
||||
|
||||
if not platform:
|
||||
raise serializers.ValidationError({'platform': _("Platform not exist")})
|
||||
@@ -297,6 +315,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
|
||||
def is_valid(self, raise_exception=False):
|
||||
self._set_protocols_default()
|
||||
self._set_platform()
|
||||
return super().is_valid(raise_exception=raise_exception)
|
||||
|
||||
def validate_protocols(self, protocols_data):
|
||||
|
||||
1
apps/assets/tree/__init__.py
Normal file
1
apps/assets/tree/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .asset_tree import *
|
||||
88
apps/assets/tree/asset_tree.py
Normal file
88
apps/assets/tree/asset_tree.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from collections import defaultdict
|
||||
from django.db.models import Count
|
||||
|
||||
from orgs.utils import current_org
|
||||
from assets.models import Asset, Node
|
||||
from common.utils import get_logger, timeit
|
||||
|
||||
from .tree import TreeNode, Tree
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
__all__ = ['AssetTree']
|
||||
|
||||
|
||||
class AssetTreeNode(TreeNode):
|
||||
|
||||
def __init__(self, _id, key: str, value: str, assets_count: int=0):
|
||||
super().__init__(_id, key, value)
|
||||
self.assets_count = assets_count
|
||||
self.assets_count_total = 0
|
||||
|
||||
def as_dict(self, simple=True):
|
||||
base_dict = super().as_dict(simple=simple)
|
||||
base_dict.update({
|
||||
'assets_count_total': self.assets_count_total,
|
||||
})
|
||||
if not simple:
|
||||
base_dict.update({
|
||||
'assets_count': self.assets_count,
|
||||
})
|
||||
return base_dict
|
||||
|
||||
|
||||
class AssetTree(Tree):
|
||||
|
||||
def __init__(self, org=None):
|
||||
super().__init__()
|
||||
self._org = org or current_org()
|
||||
self._nodes_attr_mapper = defaultdict(dict)
|
||||
self._nodes_assets_count_mapper = defaultdict(int)
|
||||
|
||||
@timeit
|
||||
def build(self):
|
||||
self._load_nodes_attr_mapper()
|
||||
self._load_nodes_assets_count()
|
||||
self._init_tree()
|
||||
self._compute_assets_count_total()
|
||||
|
||||
@timeit
|
||||
def _load_nodes_attr_mapper(self):
|
||||
nodes = Node.objects.filter(org_id=self._org.id).values('id', 'key', 'value')
|
||||
# 保证节点按 key 顺序加载,以便后续构建树时父节点总在子节点前面
|
||||
nodes = sorted(nodes, key=lambda n: [int(i) for i in n['key'].split(':')])
|
||||
for node in list(nodes):
|
||||
node['id'] = str(node['id'])
|
||||
self._nodes_attr_mapper[node['id']] = node
|
||||
|
||||
@timeit
|
||||
def _load_nodes_assets_count(self):
|
||||
nodes_count = Asset.objects.filter(org_id=self._org.id).values('node_id').annotate(
|
||||
count=Count('id')
|
||||
).values('node_id', 'count')
|
||||
for nc in list(nodes_count):
|
||||
nc['node_id'] = str(nc['node_id'])
|
||||
self._nodes_assets_count_mapper[nc['node_id']] = nc['count']
|
||||
|
||||
@timeit
|
||||
def _init_tree(self):
|
||||
for nid, attr in self._nodes_attr_mapper.items():
|
||||
assets_count = self._nodes_assets_count_mapper.get(nid, 0)
|
||||
node = AssetTreeNode(
|
||||
_id=nid,
|
||||
key=attr['key'],
|
||||
value=attr['value'],
|
||||
assets_count=assets_count
|
||||
)
|
||||
self.add_node(node)
|
||||
|
||||
@timeit
|
||||
def _compute_assets_count_total(self):
|
||||
for node in reversed(list(self.nodes.values())):
|
||||
total = node.assets_count
|
||||
for child in node.children:
|
||||
child: AssetTreeNode
|
||||
total += child.assets_count_total
|
||||
node: AssetTreeNode
|
||||
node.assets_count_total = total
|
||||
126
apps/assets/tree/tree.py
Normal file
126
apps/assets/tree/tree.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from common.utils import get_logger, lazyproperty
|
||||
|
||||
|
||||
__all__ = ['TreeNode', 'Tree']
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TreeNode(object):
|
||||
|
||||
def __init__(self, _id, key, value):
|
||||
self.id = _id
|
||||
self.key = key
|
||||
self.value = value
|
||||
self.children = []
|
||||
self.parent = None
|
||||
|
||||
@lazyproperty
|
||||
def parent_key(self):
|
||||
if self.is_root:
|
||||
return None
|
||||
return ':'.join(self.key.split(':')[:-1])
|
||||
|
||||
@property
|
||||
def is_root(self):
|
||||
return self.key.isdigit()
|
||||
|
||||
def add_child(self, child_node: 'TreeNode'):
|
||||
child_node.parent = self
|
||||
self.children.append(child_node)
|
||||
|
||||
def remove_child(self, child_node: 'TreeNode'):
|
||||
self.parent.children.remove(child_node)
|
||||
self.parent = None
|
||||
|
||||
@property
|
||||
def is_leaf(self):
|
||||
return len(self.children) == 0
|
||||
|
||||
@lazyproperty
|
||||
def level(self):
|
||||
return self.key.count(':') + 1
|
||||
|
||||
@property
|
||||
def children_count(self):
|
||||
return len(self.children)
|
||||
|
||||
def as_dict(self, simple=True):
|
||||
data = {
|
||||
'key': self.key,
|
||||
}
|
||||
if not simple:
|
||||
data.update({
|
||||
'id': self.id,
|
||||
'value': self.value,
|
||||
'level': self.level,
|
||||
'children_count': self.children_count,
|
||||
'is_root': self.is_root,
|
||||
'is_leaf': self.is_leaf,
|
||||
})
|
||||
return data
|
||||
|
||||
def print(self, simple=True):
|
||||
info = [f"{k}: {v}" for k, v in self.as_dict(simple=simple).items()]
|
||||
print(' | '.join(info))
|
||||
|
||||
|
||||
class Tree(object):
|
||||
|
||||
def __init__(self):
|
||||
self.root = None
|
||||
self.nodes: dict[TreeNode] = {}
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return len(self.nodes)
|
||||
|
||||
@property
|
||||
def is_empty(self):
|
||||
return self.size == 0
|
||||
|
||||
@property
|
||||
def depth(self):
|
||||
" 返回树的最大深度,以及对应的节点key "
|
||||
if self.is_empty:
|
||||
return 0, 0
|
||||
node = max(self.nodes.values(), key=lambda node: node.level)
|
||||
node: TreeNode
|
||||
logger.debug(f"Max depth node key: {node.key}")
|
||||
return node.level
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
" 返回树的最大宽度,以及对应的层级数 "
|
||||
if self.is_empty:
|
||||
return 0, 0
|
||||
node = max(self.nodes.values(), key=lambda node: node.children_count)
|
||||
node: TreeNode
|
||||
logger.debug(f"Max width level: {node.level + 1}")
|
||||
return node.children_count
|
||||
|
||||
def add_node(self, node: TreeNode):
|
||||
if node.is_root:
|
||||
self.root = node
|
||||
self.nodes[node.key] = node
|
||||
return
|
||||
parent = self.get_node(node.parent_key)
|
||||
if not parent:
|
||||
error = f""" Cannot add node {node.key}: parent key {node.parent_key} not found.
|
||||
Please ensure parent nodes are added before child nodes."""
|
||||
raise ValueError(error)
|
||||
parent.add_child(node)
|
||||
self.nodes[node.key] = node
|
||||
|
||||
def get_node(self, key: str) -> TreeNode:
|
||||
return self.nodes.get(key)
|
||||
|
||||
def print(self, count=10, simple=True):
|
||||
print('Tree root: ', getattr(self.root, 'key', 'No-root'))
|
||||
print('Tree size: ', self.size)
|
||||
print('Tree depth: ', self.depth)
|
||||
print('Tree width: ', self.width)
|
||||
|
||||
for n in list(self.nodes.values())[:count]:
|
||||
n: TreeNode
|
||||
n.print(simple=simple)
|
||||
@@ -43,7 +43,7 @@ from .serializers import (
|
||||
OperateLogSerializer, OperateLogActionDetailSerializer,
|
||||
PasswordChangeLogSerializer, ActivityUnionLogSerializer,
|
||||
FileSerializer, UserSessionSerializer, JobsAuditSerializer,
|
||||
ServiceAccessLogSerializer
|
||||
ServiceAccessLogSerializer, OperateLogFullSerializer
|
||||
)
|
||||
from .utils import construct_userlogin_usernames, record_operate_log_and_activity_log
|
||||
|
||||
@@ -256,7 +256,9 @@ class OperateLogViewSet(OrgReadonlyModelViewSet):
|
||||
def get_serializer_class(self):
|
||||
if self.is_action_detail:
|
||||
return OperateLogActionDetailSerializer
|
||||
return super().get_serializer_class()
|
||||
elif self.request.query_params.get('format'):
|
||||
return OperateLogFullSerializer
|
||||
return OperateLogSerializer
|
||||
|
||||
def get_queryset(self):
|
||||
current_org_id = str(current_org.id)
|
||||
|
||||
@@ -127,6 +127,21 @@ class OperateLogSerializer(BulkOrgResourceModelSerializer):
|
||||
return i18n_trans(instance.resource)
|
||||
|
||||
|
||||
class DiffFieldSerializer(serializers.JSONField):
|
||||
def to_file_representation(self, value):
|
||||
row = getattr(self, '_row') or {}
|
||||
attrs = {'diff': value, 'resource_type': row.get('resource_type')}
|
||||
instance = type('OperateLog', (), attrs)
|
||||
return OperateLogStore.convert_diff_friendly(instance)
|
||||
|
||||
|
||||
class OperateLogFullSerializer(OperateLogSerializer):
|
||||
diff = DiffFieldSerializer(label=_("Diff"))
|
||||
|
||||
class Meta(OperateLogSerializer.Meta):
|
||||
fields = OperateLogSerializer.Meta.fields + ['diff']
|
||||
|
||||
|
||||
class PasswordChangeLogSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = models.PasswordChangeLog
|
||||
|
||||
@@ -16,3 +16,4 @@ from .sso import *
|
||||
from .temp_token import *
|
||||
from .token import *
|
||||
from .face import *
|
||||
from .access_token import *
|
||||
|
||||
47
apps/authentication/api/access_token.py
Normal file
47
apps/authentication/api/access_token.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
|
||||
from oauth2_provider.models import get_access_token_model
|
||||
|
||||
from common.api import JMSModelViewSet
|
||||
from rbac.permissions import RBACPermission
|
||||
from ..serializers import AccessTokenSerializer
|
||||
|
||||
|
||||
AccessToken = get_access_token_model()
|
||||
|
||||
|
||||
class AccessTokenViewSet(JMSModelViewSet):
|
||||
"""
|
||||
OAuth2 Access Token 管理视图集
|
||||
用户只能查看和撤销自己的 access token
|
||||
"""
|
||||
serializer_class = AccessTokenSerializer
|
||||
permission_classes = [RBACPermission]
|
||||
http_method_names = ['get', 'options', 'delete']
|
||||
rbac_perms = {
|
||||
'revoke': 'oauth2_provider.delete_accesstoken',
|
||||
}
|
||||
|
||||
def get_queryset(self):
|
||||
"""只返回当前用户的 access token,按创建时间倒序"""
|
||||
return AccessToken.objects.filter(user=self.request.user).order_by('-created')
|
||||
|
||||
@action(methods=['DELETE'], detail=True, url_path='revoke')
|
||||
def revoke(self, request, *args, **kwargs):
|
||||
"""
|
||||
撤销 access token 及其关联的 refresh token
|
||||
如果 token 不存在或不属于当前用户,返回 404
|
||||
"""
|
||||
token = get_object_or_404(
|
||||
AccessToken.objects.filter(user=request.user),
|
||||
id=kwargs['pk']
|
||||
)
|
||||
# 优先撤销 refresh token,会自动撤销关联的 access token
|
||||
token_to_revoke = token.refresh_token if token.refresh_token else token
|
||||
token_to_revoke.revoke()
|
||||
return Response(status=HTTP_204_NO_CONTENT)
|
||||
@@ -1,6 +1,5 @@
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.backends import ModelBackend
|
||||
from django.views import View
|
||||
|
||||
from common.utils import get_logger
|
||||
from users.models import User
|
||||
@@ -66,11 +65,3 @@ class JMSBaseAuthBackend:
|
||||
class JMSModelBackend(JMSBaseAuthBackend, ModelBackend):
|
||||
def user_can_authenticate(self, user):
|
||||
return True
|
||||
|
||||
|
||||
class BaseAuthCallbackClientView(View):
|
||||
http_method_names = ['get']
|
||||
|
||||
def get(self, request):
|
||||
from authentication.views.utils import redirect_to_guard_view
|
||||
return redirect_to_guard_view(query_string='next=client')
|
||||
|
||||
@@ -1,51 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
|
||||
import threading
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import get_user_model
|
||||
from django_cas_ng.backends import CASBackend as _CASBackend
|
||||
|
||||
from common.utils import get_logger
|
||||
from ..base import JMSBaseAuthBackend
|
||||
|
||||
__all__ = ['CASBackend', 'CASUserDoesNotExist']
|
||||
__all__ = ['CASBackend']
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CASUserDoesNotExist(Exception):
|
||||
"""Exception raised when a CAS user does not exist."""
|
||||
pass
|
||||
|
||||
|
||||
class CASBackend(JMSBaseAuthBackend, _CASBackend):
|
||||
@staticmethod
|
||||
def is_enabled():
|
||||
return settings.AUTH_CAS
|
||||
|
||||
def authenticate(self, request, ticket, service):
|
||||
UserModel = get_user_model()
|
||||
manager = UserModel._default_manager
|
||||
original_get_by_natural_key = manager.get_by_natural_key
|
||||
thread_local = threading.local()
|
||||
thread_local.thread_id = threading.get_ident()
|
||||
logger.debug(f"CASBackend.authenticate: thread_id={thread_local.thread_id}")
|
||||
|
||||
def get_by_natural_key(self, username):
|
||||
logger.debug(f"CASBackend.get_by_natural_key: thread_id={threading.get_ident()}, username={username}")
|
||||
if threading.get_ident() != thread_local.thread_id:
|
||||
return original_get_by_natural_key(username)
|
||||
|
||||
try:
|
||||
user = original_get_by_natural_key(username)
|
||||
except UserModel.DoesNotExist:
|
||||
raise CASUserDoesNotExist(username)
|
||||
return user
|
||||
|
||||
try:
|
||||
manager.get_by_natural_key = get_by_natural_key.__get__(manager, type(manager))
|
||||
user = super().authenticate(request, ticket=ticket, service=service)
|
||||
finally:
|
||||
manager.get_by_natural_key = original_get_by_natural_key
|
||||
return user
|
||||
# 这里做个hack ,让父类始终走CAS_CREATE_USER=True的逻辑,然后调用 authentication/mixins.py 中的 custom_get_or_create 方法
|
||||
settings.CAS_CREATE_USER = True
|
||||
return super().authenticate(request, ticket, service)
|
||||
|
||||
@@ -3,11 +3,10 @@
|
||||
import django_cas_ng.views
|
||||
from django.urls import path
|
||||
|
||||
from .views import CASLoginView, CASCallbackClientView
|
||||
from .views import CASLoginView
|
||||
|
||||
urlpatterns = [
|
||||
path('login/', CASLoginView.as_view(), name='cas-login'),
|
||||
path('logout/', django_cas_ng.views.LogoutView.as_view(), name='cas-logout'),
|
||||
path('callback/', django_cas_ng.views.CallbackView.as_view(), name='cas-proxy-callback'),
|
||||
path('login/client', CASCallbackClientView.as_view(), name='cas-proxy-callback-client'),
|
||||
path('callback/', django_cas_ng.views.CallbackView.as_view(), name='cas-proxy-callback')
|
||||
]
|
||||
|
||||
@@ -3,31 +3,20 @@ from django.http import HttpResponseRedirect
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_cas_ng.views import LoginView
|
||||
|
||||
from authentication.backends.base import BaseAuthCallbackClientView
|
||||
from common.utils import FlashMessageUtil
|
||||
from .backends import CASUserDoesNotExist
|
||||
from authentication.views.mixins import FlashMessageMixin
|
||||
|
||||
__all__ = ['LoginView']
|
||||
|
||||
|
||||
class CASLoginView(LoginView):
|
||||
class CASLoginView(LoginView, FlashMessageMixin):
|
||||
def get(self, request):
|
||||
try:
|
||||
resp = super().get(request)
|
||||
return resp
|
||||
except PermissionDenied:
|
||||
return HttpResponseRedirect('/')
|
||||
except CASUserDoesNotExist as e:
|
||||
message_data = {
|
||||
'title': _('User does not exist: {}').format(e),
|
||||
'error': _(
|
||||
'CAS login was successful, but no corresponding local user was found in the system, and automatic '
|
||||
'user creation is disabled in the CAS authentication configuration. Login failed.'),
|
||||
'interval': 10,
|
||||
'redirect_url': '/',
|
||||
}
|
||||
return FlashMessageUtil.gen_and_redirect_to(message_data)
|
||||
|
||||
|
||||
class CASCallbackClientView(BaseAuthCallbackClientView):
|
||||
pass
|
||||
resp = HttpResponseRedirect('/')
|
||||
error_message = getattr(request, 'error_message', '')
|
||||
if error_message:
|
||||
response = self.get_failed_response('/', title=_('CAS Error'), msg=error_message)
|
||||
return response
|
||||
else:
|
||||
return resp
|
||||
|
||||
@@ -69,6 +69,8 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
|
||||
msg = _('Invalid token header. Sign string should not contain invalid characters.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
user, header = self.authenticate_credentials(token)
|
||||
if not user:
|
||||
return None
|
||||
after_authenticate_update_date(user)
|
||||
return user, header
|
||||
|
||||
@@ -77,10 +79,6 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
|
||||
model = get_user_model()
|
||||
user_id = cache.get(token)
|
||||
user = get_object_or_none(model, id=user_id)
|
||||
|
||||
if not user:
|
||||
msg = _('Invalid token or cache refreshed.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
return user, None
|
||||
|
||||
def authenticate_header(self, request):
|
||||
|
||||
@@ -7,6 +7,5 @@ from . import views
|
||||
urlpatterns = [
|
||||
path('login/', views.OAuth2AuthRequestView.as_view(), name='login'),
|
||||
path('callback/', views.OAuth2AuthCallbackView.as_view(), name='login-callback'),
|
||||
path('callback/client/', views.OAuth2AuthCallbackClientView.as_view(), name='login-callback-client'),
|
||||
path('logout/', views.OAuth2EndSessionView.as_view(), name='logout')
|
||||
]
|
||||
|
||||
@@ -6,27 +6,34 @@ from django.utils.http import urlencode
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.views import View
|
||||
|
||||
from authentication.backends.base import BaseAuthCallbackClientView
|
||||
from authentication.decorators import pre_save_next_to_session, redirect_to_pre_save_next_after_auth
|
||||
from authentication.mixins import authenticate
|
||||
from authentication.utils import build_absolute_uri
|
||||
from authentication.views.mixins import FlashMessageMixin
|
||||
from common.utils import get_logger
|
||||
from common.utils import get_logger, safe_next_url
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class OAuth2AuthRequestView(View):
|
||||
|
||||
@pre_save_next_to_session()
|
||||
def get(self, request):
|
||||
log_prompt = "Process OAuth2 GET requests: {}"
|
||||
logger.debug(log_prompt.format('Start'))
|
||||
|
||||
request_params = request.GET.dict()
|
||||
request_params.pop('next', None)
|
||||
query = urlencode(request_params)
|
||||
redirect_uri = build_absolute_uri(
|
||||
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
|
||||
)
|
||||
redirect_uri = f"{redirect_uri}?{query}"
|
||||
|
||||
query_dict = {
|
||||
'client_id': settings.AUTH_OAUTH2_CLIENT_ID, 'response_type': 'code',
|
||||
'scope': settings.AUTH_OAUTH2_SCOPE,
|
||||
'redirect_uri': build_absolute_uri(
|
||||
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
|
||||
)
|
||||
'redirect_uri': redirect_uri
|
||||
}
|
||||
|
||||
if '?' in settings.AUTH_OAUTH2_PROVIDER_AUTHORIZATION_ENDPOINT:
|
||||
@@ -45,6 +52,7 @@ class OAuth2AuthRequestView(View):
|
||||
class OAuth2AuthCallbackView(View, FlashMessageMixin):
|
||||
http_method_names = ['get', ]
|
||||
|
||||
@redirect_to_pre_save_next_after_auth
|
||||
def get(self, request):
|
||||
""" Processes GET requests. """
|
||||
log_prompt = "Process GET requests [OAuth2AuthCallbackView]: {}"
|
||||
@@ -59,9 +67,7 @@ class OAuth2AuthCallbackView(View, FlashMessageMixin):
|
||||
logger.debug(log_prompt.format('Login: {}'.format(user)))
|
||||
auth.login(self.request, user)
|
||||
logger.debug(log_prompt.format('Redirect'))
|
||||
return HttpResponseRedirect(
|
||||
settings.AUTH_OAUTH2_AUTHENTICATION_REDIRECT_URI
|
||||
)
|
||||
return HttpResponseRedirect(settings.AUTH_OAUTH2_AUTHENTICATION_REDIRECT_URI)
|
||||
else:
|
||||
if getattr(request, 'error_message', ''):
|
||||
response = self.get_failed_response('/', title=_('OAuth2 Error'), msg=request.error_message)
|
||||
@@ -72,10 +78,6 @@ class OAuth2AuthCallbackView(View, FlashMessageMixin):
|
||||
return HttpResponseRedirect(redirect_url)
|
||||
|
||||
|
||||
class OAuth2AuthCallbackClientView(BaseAuthCallbackClientView):
|
||||
pass
|
||||
|
||||
|
||||
class OAuth2EndSessionView(View):
|
||||
http_method_names = ['get', 'post', ]
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
from django.db.models.signals import post_delete
|
||||
from django.dispatch import receiver
|
||||
from django.core.cache import cache
|
||||
from django.conf import settings
|
||||
|
||||
from oauth2_provider.models import get_application_model
|
||||
|
||||
from .utils import clear_oauth2_authorization_server_view_cache
|
||||
|
||||
__all__ = ['on_oauth2_provider_application_deleted']
|
||||
|
||||
|
||||
Application = get_application_model()
|
||||
|
||||
|
||||
@receiver(post_delete, sender=Application)
|
||||
def on_oauth2_provider_application_deleted(sender, instance, **kwargs):
|
||||
if instance.name == settings.OAUTH2_PROVIDER_JUMPSERVER_CLIENT_NAME:
|
||||
clear_oauth2_authorization_server_view_cache()
|
||||
|
||||
14
apps/authentication/backends/oauth2_provider/urls.py
Normal file
14
apps/authentication/backends/oauth2_provider/urls.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from django.urls import path
|
||||
|
||||
from oauth2_provider import views as op_views
|
||||
from . import views
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
path("authorize/", op_views.AuthorizationView.as_view(), name="authorize"),
|
||||
path("token/", op_views.TokenView.as_view(), name="token"),
|
||||
path("revoke/", op_views.RevokeTokenView.as_view(), name="revoke-token"),
|
||||
path(".well-known/oauth-authorization-server", views.OAuthAuthorizationServerView.as_view(), name="oauth-authorization-server"),
|
||||
]
|
||||
31
apps/authentication/backends/oauth2_provider/utils.py
Normal file
31
apps/authentication/backends/oauth2_provider/utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
from oauth2_provider.models import get_application_model
|
||||
|
||||
from common.utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def get_or_create_jumpserver_client_application():
|
||||
"""Auto get or create OAuth2 JumpServer Client application."""
|
||||
Application = get_application_model()
|
||||
|
||||
application, created = Application.objects.get_or_create(
|
||||
name=settings.OAUTH2_PROVIDER_JUMPSERVER_CLIENT_NAME,
|
||||
defaults={
|
||||
'client_type': Application.CLIENT_PUBLIC,
|
||||
'authorization_grant_type': Application.GRANT_AUTHORIZATION_CODE,
|
||||
'redirect_uris': settings.OAUTH2_PROVIDER_CLIENT_REDIRECT_URI,
|
||||
'skip_authorization': True,
|
||||
}
|
||||
)
|
||||
return application
|
||||
|
||||
|
||||
CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX = 'oauth2_provider_metadata'
|
||||
|
||||
|
||||
def clear_oauth2_authorization_server_view_cache():
|
||||
logger.info("Clearing OAuth2 Authorization Server Metadata view cache")
|
||||
cache_key = f'views.decorators.cache.cache_page.{CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX}.GET*'
|
||||
cache.delete_pattern(cache_key)
|
||||
77
apps/authentication/backends/oauth2_provider/views.py
Normal file
77
apps/authentication/backends/oauth2_provider/views.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from django.views.generic import View
|
||||
from django.http import JsonResponse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.cache import cache_page
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.conf import settings
|
||||
from django.urls import reverse
|
||||
from oauth2_provider.settings import oauth2_settings
|
||||
from typing import List, Dict, Any
|
||||
from .utils import get_or_create_jumpserver_client_application, CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX
|
||||
|
||||
|
||||
@method_decorator(csrf_exempt, name='dispatch')
|
||||
@method_decorator(cache_page(timeout=60 * 60 * 24, key_prefix=CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX), name='dispatch')
|
||||
class OAuthAuthorizationServerView(View):
|
||||
"""
|
||||
OAuth 2.0 Authorization Server Metadata Endpoint
|
||||
RFC 8414: https://datatracker.ietf.org/doc/html/rfc8414
|
||||
|
||||
This endpoint provides machine-readable information about the
|
||||
OAuth 2.0 authorization server's configuration.
|
||||
"""
|
||||
|
||||
def get_base_url(self, request) -> str:
|
||||
scheme = 'https' if request.is_secure() else 'http'
|
||||
host = request.get_host()
|
||||
return f"{scheme}://{host}"
|
||||
|
||||
def get_supported_scopes(self) -> List[str]:
|
||||
scopes_config = oauth2_settings.SCOPES
|
||||
if isinstance(scopes_config, dict):
|
||||
return list(scopes_config.keys())
|
||||
return []
|
||||
|
||||
def get_metadata(self, request) -> Dict[str, Any]:
|
||||
base_url = self.get_base_url(request)
|
||||
application = get_or_create_jumpserver_client_application()
|
||||
metadata = {
|
||||
"issuer": base_url,
|
||||
"client_id": application.client_id if application else "Not found any application.",
|
||||
"authorization_endpoint": base_url + reverse('authentication:oauth2-provider:authorize'),
|
||||
"token_endpoint": base_url + reverse('authentication:oauth2-provider:token'),
|
||||
"revocation_endpoint": base_url + reverse('authentication:oauth2-provider:revoke-token'),
|
||||
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"scopes_supported": self.get_supported_scopes(),
|
||||
|
||||
"token_endpoint_auth_methods_supported": ["none"],
|
||||
"revocation_endpoint_auth_methods_supported": ["none"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"response_modes_supported": ["query"],
|
||||
}
|
||||
if hasattr(oauth2_settings, 'ACCESS_TOKEN_EXPIRE_SECONDS'):
|
||||
metadata["token_expires_in"] = oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS
|
||||
if hasattr(oauth2_settings, 'REFRESH_TOKEN_EXPIRE_SECONDS'):
|
||||
if oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS:
|
||||
metadata["refresh_token_expires_in"] = oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS
|
||||
return metadata
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
metadata = self.get_metadata(request)
|
||||
response = JsonResponse(metadata)
|
||||
self.add_cors_headers(response)
|
||||
return response
|
||||
|
||||
def options(self, request, *args, **kwargs):
|
||||
response = JsonResponse({})
|
||||
self.add_cors_headers(response)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def add_cors_headers(response):
|
||||
response['Access-Control-Allow-Origin'] = '*'
|
||||
response['Access-Control-Allow-Methods'] = 'GET, OPTIONS'
|
||||
response['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
|
||||
response['Access-Control-Max-Age'] = '3600'
|
||||
@@ -15,6 +15,5 @@ from . import views
|
||||
urlpatterns = [
|
||||
path('login/', views.OIDCAuthRequestView.as_view(), name='login'),
|
||||
path('callback/', views.OIDCAuthCallbackView.as_view(), name='login-callback'),
|
||||
path('callback/client/', views.OIDCAuthCallbackClientView.as_view(), name='login-callback-client'),
|
||||
path('logout/', views.OIDCEndSessionView.as_view(), name='logout'),
|
||||
]
|
||||
|
||||
@@ -25,11 +25,11 @@ from django.utils.http import urlencode
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.views.generic import View
|
||||
|
||||
from authentication.decorators import pre_save_next_to_session, redirect_to_pre_save_next_after_auth
|
||||
from authentication.utils import build_absolute_uri_for_oidc
|
||||
from authentication.views.mixins import FlashMessageMixin
|
||||
from common.utils import safe_next_url
|
||||
from .utils import get_logger
|
||||
from ..base import BaseAuthCallbackClientView
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
@@ -58,6 +58,7 @@ class OIDCAuthRequestView(View):
|
||||
b = base64.urlsafe_b64encode(h)
|
||||
return b.decode('ascii')[:-1]
|
||||
|
||||
@pre_save_next_to_session()
|
||||
def get(self, request):
|
||||
""" Processes GET requests. """
|
||||
|
||||
@@ -66,8 +67,9 @@ class OIDCAuthRequestView(View):
|
||||
|
||||
# Defines common parameters used to bootstrap the authentication request.
|
||||
logger.debug(log_prompt.format('Construct request params'))
|
||||
authentication_request_params = request.GET.dict()
|
||||
authentication_request_params.update({
|
||||
request_params = request.GET.dict()
|
||||
request_params.pop('next', None)
|
||||
request_params.update({
|
||||
'scope': settings.AUTH_OPENID_SCOPES,
|
||||
'response_type': 'code',
|
||||
'client_id': settings.AUTH_OPENID_CLIENT_ID,
|
||||
@@ -80,7 +82,7 @@ class OIDCAuthRequestView(View):
|
||||
code_verifier = self.gen_code_verifier()
|
||||
code_challenge_method = settings.AUTH_OPENID_CODE_CHALLENGE_METHOD or 'S256'
|
||||
code_challenge = self.gen_code_challenge(code_verifier, code_challenge_method)
|
||||
authentication_request_params.update({
|
||||
request_params.update({
|
||||
'code_challenge_method': code_challenge_method,
|
||||
'code_challenge': code_challenge
|
||||
})
|
||||
@@ -91,7 +93,7 @@ class OIDCAuthRequestView(View):
|
||||
if settings.AUTH_OPENID_USE_STATE:
|
||||
logger.debug(log_prompt.format('Use state'))
|
||||
state = get_random_string(settings.AUTH_OPENID_STATE_LENGTH)
|
||||
authentication_request_params.update({'state': state})
|
||||
request_params.update({'state': state})
|
||||
request.session['oidc_auth_state'] = state
|
||||
|
||||
# Nonces should be used too! In that case the generated nonce is stored both in the
|
||||
@@ -99,17 +101,12 @@ class OIDCAuthRequestView(View):
|
||||
if settings.AUTH_OPENID_USE_NONCE:
|
||||
logger.debug(log_prompt.format('Use nonce'))
|
||||
nonce = get_random_string(settings.AUTH_OPENID_NONCE_LENGTH)
|
||||
authentication_request_params.update({'nonce': nonce, })
|
||||
request_params.update({'nonce': nonce, })
|
||||
request.session['oidc_auth_nonce'] = nonce
|
||||
|
||||
# Stores the "next" URL in the session if applicable.
|
||||
logger.debug(log_prompt.format('Stores next url in the session'))
|
||||
next_url = request.GET.get('next')
|
||||
request.session['oidc_auth_next_url'] = safe_next_url(next_url, request=request)
|
||||
|
||||
# Redirects the user to authorization endpoint.
|
||||
logger.debug(log_prompt.format('Construct redirect url'))
|
||||
query = urlencode(authentication_request_params)
|
||||
query = urlencode(request_params)
|
||||
redirect_url = '{url}?{query}'.format(
|
||||
url=settings.AUTH_OPENID_PROVIDER_AUTHORIZATION_ENDPOINT, query=query)
|
||||
|
||||
@@ -129,6 +126,8 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
|
||||
|
||||
http_method_names = ['get', ]
|
||||
|
||||
|
||||
@redirect_to_pre_save_next_after_auth
|
||||
def get(self, request):
|
||||
""" Processes GET requests. """
|
||||
log_prompt = "Process GET requests [OIDCAuthCallbackView]: {}"
|
||||
@@ -167,7 +166,6 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
|
||||
raise SuspiciousOperation('Invalid OpenID Connect callback state value')
|
||||
|
||||
# Authenticates the end-user.
|
||||
next_url = request.session.get('oidc_auth_next_url', None)
|
||||
code_verifier = request.session.get('oidc_auth_code_verifier', None)
|
||||
logger.debug(log_prompt.format('Process authenticate'))
|
||||
try:
|
||||
@@ -191,9 +189,7 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
|
||||
callback_params.get('session_state', None)
|
||||
|
||||
logger.debug(log_prompt.format('Redirect'))
|
||||
return HttpResponseRedirect(
|
||||
next_url or settings.AUTH_OPENID_AUTHENTICATION_REDIRECT_URI
|
||||
)
|
||||
return HttpResponseRedirect(settings.AUTH_OPENID_AUTHENTICATION_REDIRECT_URI)
|
||||
if 'error' in callback_params:
|
||||
logger.debug(
|
||||
log_prompt.format('Error in callback params: {}'.format(callback_params['error']))
|
||||
@@ -212,10 +208,6 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
|
||||
return HttpResponseRedirect(redirect_url)
|
||||
|
||||
|
||||
class OIDCAuthCallbackClientView(BaseAuthCallbackClientView):
|
||||
pass
|
||||
|
||||
|
||||
class OIDCEndSessionView(View):
|
||||
""" Allows to end the session of any user authenticated using OpenID Connect.
|
||||
|
||||
|
||||
@@ -8,6 +8,5 @@ urlpatterns = [
|
||||
path('login/', views.Saml2AuthRequestView.as_view(), name='saml2-login'),
|
||||
path('logout/', views.Saml2EndSessionView.as_view(), name='saml2-logout'),
|
||||
path('callback/', views.Saml2AuthCallbackView.as_view(), name='saml2-callback'),
|
||||
path('callback/client/', views.Saml2AuthCallbackClientView.as_view(), name='saml2-callback-client'),
|
||||
path('metadata/', views.Saml2AuthMetadataView.as_view(), name='saml2-metadata'),
|
||||
]
|
||||
|
||||
@@ -17,9 +17,8 @@ from onelogin.saml2.idp_metadata_parser import (
|
||||
)
|
||||
|
||||
from authentication.views.mixins import FlashMessageMixin
|
||||
from common.utils import get_logger
|
||||
from common.utils import get_logger, safe_next_url
|
||||
from .settings import JmsSaml2Settings
|
||||
from ..base import BaseAuthCallbackClientView
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
@@ -208,13 +207,16 @@ class Saml2AuthRequestView(View, PrepareRequestMixin):
|
||||
log_prompt = "Process SAML GET requests: {}"
|
||||
logger.debug(log_prompt.format('Start'))
|
||||
|
||||
request_params = request.GET.dict()
|
||||
|
||||
try:
|
||||
saml_instance = self.init_saml_auth(request)
|
||||
except OneLogin_Saml2_Error as error:
|
||||
logger.error(log_prompt.format('Init saml auth error: %s' % error))
|
||||
return HttpResponse(error, status=412)
|
||||
|
||||
next_url = settings.AUTH_SAML2_PROVIDER_AUTHORIZATION_ENDPOINT
|
||||
next_url = request_params.get('next') or settings.AUTH_SAML2_PROVIDER_AUTHORIZATION_ENDPOINT
|
||||
next_url = safe_next_url(next_url, request=request)
|
||||
url = saml_instance.login(return_to=next_url)
|
||||
logger.debug(log_prompt.format('Redirect login url'))
|
||||
return HttpResponseRedirect(url)
|
||||
@@ -293,10 +295,11 @@ class Saml2AuthCallbackView(View, PrepareRequestMixin, FlashMessageMixin):
|
||||
return response
|
||||
|
||||
logger.debug(log_prompt.format('Redirect'))
|
||||
redir = post_data.get('RelayState')
|
||||
if not redir or len(redir) == 0:
|
||||
redir = "/"
|
||||
next_url = saml_instance.redirect_to(redir)
|
||||
relay_state = post_data.get('RelayState')
|
||||
if not relay_state or len(relay_state) == 0:
|
||||
relay_state = "/"
|
||||
next_url = saml_instance.redirect_to(relay_state)
|
||||
next_url = safe_next_url(next_url, request=request)
|
||||
return HttpResponseRedirect(next_url)
|
||||
|
||||
@csrf_exempt
|
||||
@@ -304,10 +307,6 @@ class Saml2AuthCallbackView(View, PrepareRequestMixin, FlashMessageMixin):
|
||||
return super().dispatch(*args, **kwargs)
|
||||
|
||||
|
||||
class Saml2AuthCallbackClientView(BaseAuthCallbackClientView):
|
||||
pass
|
||||
|
||||
|
||||
class Saml2AuthMetadataView(View, PrepareRequestMixin):
|
||||
|
||||
def get(self, request):
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from django.db.models import TextChoices
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD = 'next'
|
||||
|
||||
RSA_PRIVATE_KEY = 'rsa_private_key'
|
||||
RSA_PUBLIC_KEY = 'rsa_public_key'
|
||||
|
||||
|
||||
193
apps/authentication/decorators.py
Normal file
193
apps/authentication/decorators.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
This module provides decorators to handle redirect URLs during the authentication flow:
|
||||
1. pre_save_next_to_session: Captures and stores the intended next URL before redirecting to auth provider
|
||||
2. redirect_to_pre_save_next_after_auth: Redirects to the stored next URL after successful authentication
|
||||
3. post_save_next_to_session: Copies the stored next URL to session['next'] after view execution
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from django.http import HttpResponseRedirect
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from functools import wraps
|
||||
|
||||
from common.utils import get_logger, safe_next_url
|
||||
from .const import USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'pre_save_next_to_session', 'redirect_to_pre_save_next_after_auth',
|
||||
'post_save_next_to_session_if_guard_redirect'
|
||||
]
|
||||
|
||||
# Session key for storing the redirect URL after authentication
|
||||
AUTH_SESSION_NEXT_URL_KEY = 'auth_next_url'
|
||||
|
||||
|
||||
def pre_save_next_to_session(get_next_url=None):
|
||||
"""
|
||||
Decorator to capture and store the 'next' parameter into session BEFORE view execution.
|
||||
|
||||
This decorator is applied to the authentication request view to preserve the user's
|
||||
intended destination URL before redirecting to the authentication provider.
|
||||
|
||||
Args:
|
||||
get_next_url: Optional callable that extracts the next URL from request.
|
||||
Default: lambda req: req.GET.get('next')
|
||||
|
||||
Usage:
|
||||
# Use default (request.GET.get('next'))
|
||||
@pre_save_next_to_session()
|
||||
def get(self, request):
|
||||
pass
|
||||
|
||||
# Custom extraction from POST data
|
||||
@pre_save_next_to_session(get_next_url=lambda req: req.POST.get('next'))
|
||||
def post(self, request):
|
||||
pass
|
||||
|
||||
# Custom extraction from both GET and POST
|
||||
@pre_save_next_to_session(
|
||||
get_next_url=lambda req: req.GET.get('next') or req.POST.get('next')
|
||||
)
|
||||
def get(self, request):
|
||||
pass
|
||||
|
||||
Example flow:
|
||||
User accesses: /auth/oauth2/?next=/dashboard/
|
||||
↓ (decorator saves '/dashboard/' to session)
|
||||
Redirected to OAuth2 provider for authentication
|
||||
"""
|
||||
# Default function to extract next URL from request.GET
|
||||
if get_next_url is None:
|
||||
get_next_url = lambda req: req.GET.get('next')
|
||||
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def wrapper(self, request, *args, **kwargs):
|
||||
next_url = get_next_url(request)
|
||||
if next_url:
|
||||
request.session[AUTH_SESSION_NEXT_URL_KEY] = next_url
|
||||
logger.debug(f"[Auth] Saved next_url to session: {next_url}")
|
||||
return view_func(self, request, *args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def redirect_to_pre_save_next_after_auth(view_func):
|
||||
"""
|
||||
Decorator to redirect to the previously saved 'next' URL after successful authentication.
|
||||
|
||||
This decorator is applied to the authentication callback view. After the user successfully
|
||||
authenticates, if a 'next' URL was previously saved in the session (by pre_save_next_to_session),
|
||||
the user will be redirected to that URL instead of the default redirect location.
|
||||
|
||||
Conditions for redirect:
|
||||
- User must be authenticated (request.user.is_authenticated)
|
||||
- Session must contain the saved next URL (AUTH_SESSION_NEXT_URL_KEY)
|
||||
- The next URL must not be '/' (avoid unnecessary redirects)
|
||||
- The next URL must pass security validation (safe_next_url)
|
||||
|
||||
If any condition fails, returns the original view response.
|
||||
|
||||
Usage:
|
||||
@redirect_to_pre_save_next_after_auth
|
||||
def get(self, request):
|
||||
# Process authentication callback
|
||||
if user_authenticated:
|
||||
auth.login(request, user)
|
||||
return HttpResponseRedirect(default_url)
|
||||
|
||||
Example flow:
|
||||
User redirected back from OAuth2 provider: /auth/oauth2/callback/?code=xxx
|
||||
↓ (view processes authentication, user becomes authenticated)
|
||||
Decorator checks session for saved next URL
|
||||
↓ (finds '/dashboard/' in session)
|
||||
Redirects to: /dashboard/
|
||||
↓ (clears saved URL from session)
|
||||
"""
|
||||
@wraps(view_func)
|
||||
def wrapper(self, request, *args, **kwargs):
|
||||
# Execute the original view method first
|
||||
response = view_func(self, request, *args, **kwargs)
|
||||
|
||||
# Check if user has been authenticated
|
||||
if request.user and request.user.is_authenticated:
|
||||
# Check if session contains a saved next URL
|
||||
saved_next_url = request.session.get(AUTH_SESSION_NEXT_URL_KEY)
|
||||
|
||||
if saved_next_url and saved_next_url != '/':
|
||||
# Validate the URL for security
|
||||
safe_url = safe_next_url(saved_next_url, request=request)
|
||||
if safe_url:
|
||||
# Clear the saved URL from session (one-time use)
|
||||
request.session.pop(AUTH_SESSION_NEXT_URL_KEY, None)
|
||||
logger.debug(f"[Auth] Redirecting authenticated user to saved next_url: {safe_url}")
|
||||
return HttpResponseRedirect(safe_url)
|
||||
|
||||
# Return the original response if no redirect conditions are met
|
||||
return response
|
||||
return wrapper
|
||||
|
||||
|
||||
def post_save_next_to_session_if_guard_redirect(view_func):
|
||||
"""
|
||||
Decorator to copy AUTH_SESSION_NEXT_URL_KEY to session['next'] after view execution,
|
||||
but only if redirecting to login-guard view.
|
||||
|
||||
This decorator is applied AFTER view execution. It copies the value from
|
||||
AUTH_SESSION_NEXT_URL_KEY (internal storage) to 'next' (standard session key)
|
||||
for use by downstream code.
|
||||
|
||||
Only sets the 'next' session key when the response is a redirect to guard-view
|
||||
(i.e., response with redirect status code and location path matching login-guard view URL).
|
||||
|
||||
Usage:
|
||||
@post_save_next_to_session_if_guard_redirect
|
||||
def get(self, request):
|
||||
# Process the request and return response
|
||||
if some_condition:
|
||||
return self.redirect_to_guard_view() # Decorator will copy next to session
|
||||
return HttpResponseRedirect(url) # Decorator won't copy if not to guard-view
|
||||
|
||||
Example flow:
|
||||
View executes and returns redirect to guard view
|
||||
↓ (response is redirect with 'login-guard' in Location)
|
||||
Decorator checks if response is redirect to guard-view and session has saved next URL
|
||||
↓ (copies AUTH_SESSION_NEXT_URL_KEY to session['next'])
|
||||
User is redirected to guard-view with 'next' available in session
|
||||
"""
|
||||
@wraps(view_func)
|
||||
def wrapper(self, request, *args, **kwargs):
|
||||
# Execute the original view method
|
||||
response = view_func(self, request, *args, **kwargs)
|
||||
|
||||
# Check if response is a redirect to guard view
|
||||
# Redirect responses typically have status codes 301, 302, 303, 307, 308
|
||||
is_guard_redirect = False
|
||||
if hasattr(response, 'status_code') and response.status_code in (301, 302, 303, 307, 308):
|
||||
# Check if the redirect location is to guard view
|
||||
location = response.get('Location', '')
|
||||
if location:
|
||||
# Extract path from location URL (handle both absolute and relative URLs)
|
||||
parsed_url = urlparse(location)
|
||||
path = parsed_url.path
|
||||
|
||||
# Check if path matches guard view URL pattern
|
||||
guard_view_url = reverse('authentication:login-guard')
|
||||
if path == guard_view_url:
|
||||
is_guard_redirect = True
|
||||
|
||||
# Only set 'next' if response is a redirect to guard view
|
||||
if is_guard_redirect:
|
||||
# Copy AUTH_SESSION_NEXT_URL_KEY to 'next' if it exists
|
||||
saved_next_url = request.session.get(AUTH_SESSION_NEXT_URL_KEY)
|
||||
if saved_next_url:
|
||||
# 这里 'next' 是 UserLoginGuardView.redirect_field_name
|
||||
request.session[USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD] = saved_next_url
|
||||
logger.debug(f"[Auth] Copied {AUTH_SESSION_NEXT_URL_KEY} to 'next' in session: {saved_next_url}")
|
||||
|
||||
return response
|
||||
return wrapper
|
||||
@@ -50,7 +50,7 @@ class UserLoginForm(forms.Form):
|
||||
|
||||
class UserCheckOtpCodeForm(forms.Form):
|
||||
code = forms.CharField(label=_('MFA Code'), max_length=128, required=False)
|
||||
mfa_type = forms.CharField(label=_('MFA type'), max_length=128, required=False)
|
||||
mfa_type = forms.CharField(label=_('MFA type'), max_length=128)
|
||||
|
||||
|
||||
class CustomCaptchaTextInput(CaptchaTextInput):
|
||||
|
||||
0
apps/authentication/management/__init__.py
Normal file
0
apps/authentication/management/__init__.py
Normal file
0
apps/authentication/management/commands/__init__.py
Normal file
0
apps/authentication/management/commands/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db.utils import OperationalError, ProgrammingError
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = 'Initialize OAuth2 Provider - Create default JumpServer Client application'
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
'--force',
|
||||
action='store_true',
|
||||
help='Force recreate the application even if it exists',
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
force = options.get('force', False)
|
||||
|
||||
try:
|
||||
from authentication.backends.oauth2_provider.utils import (
|
||||
get_or_create_jumpserver_client_application
|
||||
)
|
||||
from oauth2_provider.models import get_application_model
|
||||
|
||||
Application = get_application_model()
|
||||
|
||||
# 检查表是否存在
|
||||
try:
|
||||
Application.objects.exists()
|
||||
except (OperationalError, ProgrammingError) as e:
|
||||
self.stdout.write(
|
||||
self.style.ERROR(
|
||||
f'OAuth2 Provider tables not found. Please run migrations first:\n'
|
||||
f' python manage.py migrate oauth2_provider\n'
|
||||
f'Error: {e}'
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# 如果强制重建,先删除已存在的应用
|
||||
if force:
|
||||
deleted_count, _ = Application.objects.filter(
|
||||
name=settings.OAUTH2_PROVIDER_JUMPSERVER_CLIENT_NAME
|
||||
).delete()
|
||||
if deleted_count > 0:
|
||||
self.stdout.write(
|
||||
self.style.WARNING(f'Deleted {deleted_count} existing application(s)')
|
||||
)
|
||||
|
||||
# 创建或获取应用
|
||||
application = get_or_create_jumpserver_client_application()
|
||||
|
||||
if application:
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(
|
||||
f'✓ OAuth2 JumpServer Client application initialized successfully\n'
|
||||
f' - Client ID: {application.client_id}\n'
|
||||
f' - Client Type: {application.get_client_type_display()}\n'
|
||||
f' - Grant Type: {application.get_authorization_grant_type_display()}\n'
|
||||
f' - Redirect URIs: {application.redirect_uris}\n'
|
||||
f' - Skip Authorization: {application.skip_authorization}'
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.stdout.write(
|
||||
self.style.ERROR('Failed to create OAuth2 application')
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.stdout.write(
|
||||
self.style.ERROR(f'Error initializing OAuth2 Provider: {e}')
|
||||
)
|
||||
raise
|
||||
@@ -72,9 +72,10 @@ class BaseMFA(abc.ABC):
|
||||
def is_active(self):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def global_enabled(cls):
|
||||
return cls.name in settings.SECURITY_MFA_ENABLED_BACKENDS
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def global_enabled():
|
||||
return False
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_enable_url(self) -> str:
|
||||
|
||||
@@ -39,9 +39,9 @@ class MFACustom(BaseMFA):
|
||||
def is_active(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def global_enabled(cls):
|
||||
return super().global_enabled() and settings.MFA_CUSTOM and callable(mfa_custom_method)
|
||||
@staticmethod
|
||||
def global_enabled():
|
||||
return settings.MFA_CUSTOM and callable(mfa_custom_method)
|
||||
|
||||
def get_enable_url(self) -> str:
|
||||
return ''
|
||||
|
||||
@@ -50,9 +50,9 @@ class MFAEmail(BaseMFA):
|
||||
)
|
||||
sender_util.gen_and_send_async()
|
||||
|
||||
@classmethod
|
||||
def global_enabled(cls):
|
||||
return super().global_enabled and settings.SECURITY_MFA_BY_EMAIL
|
||||
@staticmethod
|
||||
def global_enabled():
|
||||
return settings.SECURITY_MFA_BY_EMAIL
|
||||
|
||||
def disable(self):
|
||||
return '/ui/#/profile/index'
|
||||
|
||||
@@ -29,10 +29,9 @@ class MFAFace(BaseMFA, AuthFaceMixin):
|
||||
return True
|
||||
return bool(self.user.face_vector)
|
||||
|
||||
@classmethod
|
||||
def global_enabled(cls):
|
||||
@staticmethod
|
||||
def global_enabled():
|
||||
return (
|
||||
super().global_enabled() and
|
||||
settings.XPACK_LICENSE_IS_VALID and
|
||||
settings.XPACK_LICENSE_EDITION_ULTIMATE and
|
||||
settings.FACE_RECOGNITION_ENABLED
|
||||
|
||||
@@ -25,6 +25,10 @@ class MFAOtp(BaseMFA):
|
||||
return True
|
||||
return self.user.otp_secret_key
|
||||
|
||||
@staticmethod
|
||||
def global_enabled():
|
||||
return True
|
||||
|
||||
def get_enable_url(self) -> str:
|
||||
return reverse('authentication:user-otp-enable-start')
|
||||
|
||||
|
||||
@@ -23,9 +23,9 @@ class MFAPasskey(BaseMFA):
|
||||
return False
|
||||
return self.user.passkey_set.count()
|
||||
|
||||
@classmethod
|
||||
def global_enabled(cls):
|
||||
return super().global_enabled() and settings.AUTH_PASSKEY
|
||||
@staticmethod
|
||||
def global_enabled():
|
||||
return settings.AUTH_PASSKEY
|
||||
|
||||
def get_enable_url(self) -> str:
|
||||
return '/ui/#/profile/passkeys'
|
||||
|
||||
@@ -27,9 +27,9 @@ class MFARadius(BaseMFA):
|
||||
def is_active(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def global_enabled(cls):
|
||||
return super().global_enabled() and settings.OTP_IN_RADIUS
|
||||
@staticmethod
|
||||
def global_enabled():
|
||||
return settings.OTP_IN_RADIUS
|
||||
|
||||
def get_enable_url(self) -> str:
|
||||
return ''
|
||||
|
||||
@@ -46,9 +46,9 @@ class MFASms(BaseMFA):
|
||||
def send_challenge(self):
|
||||
self.sms.gen_and_send_async()
|
||||
|
||||
@classmethod
|
||||
def global_enabled(cls):
|
||||
return super().global_enabled() and settings.SMS_ENABLED
|
||||
@staticmethod
|
||||
def global_enabled():
|
||||
return settings.SMS_ENABLED
|
||||
|
||||
def get_enable_url(self) -> str:
|
||||
return '/ui/#/profile/index'
|
||||
|
||||
@@ -6,6 +6,7 @@ import time
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
from werkzeug.local import Local
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib import auth
|
||||
@@ -16,6 +17,7 @@ from django.contrib.auth import (
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.cache import cache
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db.models import Q
|
||||
from django.shortcuts import reverse, redirect, get_object_or_404
|
||||
from django.utils.http import urlencode
|
||||
from django.utils.translation import gettext as _
|
||||
@@ -31,6 +33,87 @@ from .signals import post_auth_success, post_auth_failed
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 模块级别的线程上下文,用于 authenticate 函数中标记当前线程
|
||||
_auth_thread_context = Local()
|
||||
|
||||
# 保存 Django 原始的 get_or_create 方法(在模块加载时保存一次)
|
||||
def _save_original_get_or_create():
|
||||
"""保存 Django 原始的 get_or_create 方法"""
|
||||
from django.contrib.auth import get_user_model as get_user_model_func
|
||||
UserModel = get_user_model_func()
|
||||
return UserModel.objects.get_or_create
|
||||
|
||||
_django_original_get_or_create = _save_original_get_or_create()
|
||||
|
||||
|
||||
class OnlyAllowExistUserAuthError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _authenticate_context(func):
|
||||
"""
|
||||
装饰器:管理 authenticate 函数的执行上下文
|
||||
|
||||
功能:
|
||||
1. 执行前:
|
||||
- 在线程本地存储中标记当前正在执行 authenticate
|
||||
- 临时替换 UserModel.objects.get_or_create 方法
|
||||
2. 执行后:
|
||||
- 清理线程本地存储标记
|
||||
- 恢复 get_or_create 为 Django 原始方法
|
||||
|
||||
作用:
|
||||
- 确保 get_or_create 行为仅在 authenticate 生命周期内生效
|
||||
- 支持 ONLY_ALLOW_EXIST_USER_AUTH 配置的线程安全实现
|
||||
- 防止跨请求或跨线程的状态污染
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(request=None, **credentials):
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
UserModel = get_user_model()
|
||||
|
||||
def custom_get_or_create(*args, **kwargs):
|
||||
create_username = kwargs.get('username')
|
||||
logger.debug(f"get_or_create: thread_id={threading.get_ident()}, username={create_username}")
|
||||
|
||||
# 如果当前线程正在执行 authenticate 且仅允许已存在用户认证,则提前判断用户是否存在
|
||||
if (
|
||||
getattr(_auth_thread_context, 'in_authenticate', False) and
|
||||
settings.ONLY_ALLOW_EXIST_USER_AUTH
|
||||
):
|
||||
try:
|
||||
UserModel.objects.get(username=create_username)
|
||||
except UserModel.DoesNotExist:
|
||||
raise OnlyAllowExistUserAuthError
|
||||
|
||||
# 调用 Django 原始方法(已是绑定方法,直接传参)
|
||||
return _django_original_get_or_create(*args, **kwargs)
|
||||
|
||||
|
||||
try:
|
||||
# 执行前:设置线程上下文和 monkey-patch
|
||||
setattr(_auth_thread_context, 'in_authenticate', True)
|
||||
UserModel.objects.get_or_create = custom_get_or_create
|
||||
|
||||
# 执行原函数
|
||||
return func(request, **credentials)
|
||||
finally:
|
||||
# 执行后:清理线程上下文和恢复原始方法
|
||||
try:
|
||||
if hasattr(_auth_thread_context, 'in_authenticate'):
|
||||
delattr(_auth_thread_context, 'in_authenticate')
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
UserModel.objects.get_or_create = _django_original_get_or_create
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _get_backends(return_tuples=False):
|
||||
backends = []
|
||||
@@ -48,39 +131,16 @@ def _get_backends(return_tuples=False):
|
||||
return backends
|
||||
|
||||
|
||||
class OnlyAllowExistUserAuthError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
auth._get_backends = _get_backends
|
||||
|
||||
|
||||
@_authenticate_context
|
||||
def authenticate(request=None, **credentials):
|
||||
"""
|
||||
If the given credentials are valid, return a User object.
|
||||
之所以 hack 这个 authenticate
|
||||
"""
|
||||
|
||||
UserModel = get_user_model()
|
||||
original_get_or_create = UserModel.objects.get_or_create
|
||||
|
||||
thread_local = threading.local()
|
||||
thread_local.thread_id = threading.get_ident()
|
||||
|
||||
def custom_get_or_create(self, *args, **kwargs):
|
||||
logger.debug(f"get_or_create: thread_id={threading.get_ident()}, username={username}")
|
||||
if threading.get_ident() != thread_local.thread_id or not settings.ONLY_ALLOW_EXIST_USER_AUTH:
|
||||
return original_get_or_create(*args, **kwargs)
|
||||
create_username = kwargs.get('username')
|
||||
try:
|
||||
UserModel.objects.get(username=create_username)
|
||||
except UserModel.DoesNotExist:
|
||||
raise OnlyAllowExistUserAuthError
|
||||
return original_get_or_create(*args, **kwargs)
|
||||
|
||||
username = credentials.get('username')
|
||||
|
||||
temp_user = None
|
||||
username = credentials.get('username')
|
||||
for backend, backend_path in _get_backends(return_tuples=True):
|
||||
# 检查用户名是否允许认证 (预先检查,不浪费认证时间)
|
||||
logger.info('Try using auth backend: {}'.format(str(backend)))
|
||||
@@ -94,27 +154,28 @@ def authenticate(request=None, **credentials):
|
||||
except TypeError:
|
||||
# This backend doesn't accept these credentials as arguments. Try the next one.
|
||||
continue
|
||||
|
||||
try:
|
||||
UserModel.objects.get_or_create = custom_get_or_create.__get__(UserModel.objects)
|
||||
user = backend.authenticate(request, **credentials)
|
||||
except PermissionDenied:
|
||||
# This backend says to stop in our tracks - this user should not be allowed in at all.
|
||||
break
|
||||
except OnlyAllowExistUserAuthError:
|
||||
request.error_message = _(
|
||||
'''The administrator has enabled "Only allow existing users to log in",
|
||||
and the current user is not in the user list. Please contact the administrator.'''
|
||||
)
|
||||
if request:
|
||||
request.error_message = _(
|
||||
'''The administrator has enabled "Only allow existing users to log in",
|
||||
and the current user is not in the user list. Please contact the administrator.'''
|
||||
)
|
||||
continue
|
||||
finally:
|
||||
UserModel.objects.get_or_create = original_get_or_create
|
||||
|
||||
if user is None:
|
||||
continue
|
||||
|
||||
if not user.is_valid:
|
||||
temp_user = user
|
||||
temp_user.backend = backend_path
|
||||
request.error_message = _('User is invalid')
|
||||
if request:
|
||||
request.error_message = _('User is invalid')
|
||||
return temp_user
|
||||
|
||||
# 检查用户是否允许认证
|
||||
@@ -129,8 +190,11 @@ def authenticate(request=None, **credentials):
|
||||
else:
|
||||
if temp_user is not None:
|
||||
source_display = temp_user.source_display
|
||||
request.error_message = _('''The administrator has enabled 'Only allow login from user source'.
|
||||
The current user source is {}. Please contact the administrator.''').format(source_display)
|
||||
if request:
|
||||
request.error_message = _(
|
||||
''' The administrator has enabled 'Only allow login from user source'.
|
||||
The current user source is {}. Please contact the administrator. '''
|
||||
).format(source_display)
|
||||
return temp_user
|
||||
|
||||
# The credentials supplied are invalid to all backends, fire signal
|
||||
@@ -228,7 +292,8 @@ class AuthPreCheckMixin:
|
||||
if not settings.ONLY_ALLOW_EXIST_USER_AUTH:
|
||||
return
|
||||
|
||||
exist = User.objects.filter(username=username).exists()
|
||||
q = Q(username=username) | Q(email=username)
|
||||
exist = User.objects.filter(q).exists()
|
||||
if not exist:
|
||||
logger.error(f"Only allow exist user auth, login failed: {username}")
|
||||
self.raise_credential_error(errors.reason_user_not_exist)
|
||||
|
||||
@@ -9,11 +9,12 @@ from common.utils import get_object_or_none, random_string
|
||||
from users.models import User
|
||||
from users.serializers import UserProfileSerializer
|
||||
from ..models import AccessKey, TempToken
|
||||
from oauth2_provider.models import get_access_token_model
|
||||
|
||||
__all__ = [
|
||||
'AccessKeySerializer', 'BearerTokenSerializer',
|
||||
'SSOTokenSerializer', 'TempTokenSerializer',
|
||||
'AccessKeyCreateSerializer'
|
||||
'AccessKeyCreateSerializer', 'AccessTokenSerializer',
|
||||
]
|
||||
|
||||
|
||||
@@ -114,3 +115,28 @@ class TempTokenSerializer(serializers.ModelSerializer):
|
||||
token = TempToken(**kwargs)
|
||||
token.save()
|
||||
return token
|
||||
|
||||
|
||||
class AccessTokenSerializer(serializers.ModelSerializer):
|
||||
token_preview = serializers.SerializerMethodField(label=_("Token"))
|
||||
|
||||
class Meta:
|
||||
model = get_access_token_model()
|
||||
fields = [
|
||||
'id', 'user', 'token_preview', 'is_valid',
|
||||
'is_expired', 'expires', 'scope', 'created', 'updated',
|
||||
]
|
||||
read_only_fields = fields
|
||||
extra_kwargs = {
|
||||
'scope': { 'label': _('Scope') },
|
||||
'expires': { 'label': _('Date expired') },
|
||||
'updated': { 'label': _('Date updated') },
|
||||
'created': { 'label': _('Date created') },
|
||||
}
|
||||
|
||||
|
||||
def get_token_preview(self, obj):
|
||||
token_string = obj.token
|
||||
if len(token_string) > 16:
|
||||
return f"{token_string[:6]}...{token_string[-4:]}"
|
||||
return "****"
|
||||
@@ -9,6 +9,8 @@ from audits.models import UserSession
|
||||
from common.sessions.cache import user_session_manager
|
||||
from .signals import post_auth_success, post_auth_failed, user_auth_failed, user_auth_success
|
||||
|
||||
from .backends.oauth2_provider.signal_handlers import *
|
||||
|
||||
|
||||
@receiver(user_logged_in)
|
||||
def on_user_auth_login_success(sender, user, request, **kwargs):
|
||||
@@ -57,3 +59,4 @@ def on_user_login_success(sender, request, user, backend, create=False, **kwargs
|
||||
def on_user_login_failed(sender, username, request, reason, backend, **kwargs):
|
||||
request.session['auth_backend'] = backend
|
||||
post_auth_failed.send(sender, username=username, request=request, reason=reason)
|
||||
|
||||
|
||||
@@ -47,3 +47,9 @@ def clean_expire_token():
|
||||
count = TempToken.objects.filter(date_expired__lt=expired_time).delete()
|
||||
logging.info('Deleted %d temporary tokens.', count[0])
|
||||
logging.info('Cleaned expired temporary and connection tokens.')
|
||||
|
||||
|
||||
@register_as_period_task(crontab=CRONTAB_AT_AM_TWO)
|
||||
def clear_oauth2_provider_expired_tokens():
|
||||
from oauth2_provider.models import clear_expired
|
||||
clear_expired()
|
||||
@@ -376,7 +376,7 @@
|
||||
</div>
|
||||
{% if form.challenge %}
|
||||
{% bootstrap_field form.challenge show_label=False %}
|
||||
{% elif form.mfa_type and mfa_backends %}
|
||||
{% elif form.mfa_type %}
|
||||
<div class="form-group" style="display: flex">
|
||||
{% include '_mfa_login_field.html' %}
|
||||
</div>
|
||||
|
||||
@@ -16,6 +16,7 @@ router.register('super-connection-token', api.SuperConnectionTokenViewSet, 'supe
|
||||
router.register('admin-connection-token', api.AdminConnectionTokenViewSet, 'admin-connection-token')
|
||||
router.register('confirm', api.UserConfirmationViewSet, 'confirm')
|
||||
router.register('ssh-key', api.SSHkeyViewSet, 'ssh-key')
|
||||
router.register('access-tokens', api.AccessTokenViewSet, 'access-token')
|
||||
|
||||
urlpatterns = [
|
||||
path('<str:backend>/qr/unbind/', api.QRUnBindForUserApi.as_view(), name='qr-unbind'),
|
||||
|
||||
@@ -83,4 +83,6 @@ urlpatterns = [
|
||||
path('oauth2/', include(('authentication.backends.oauth2.urls', 'authentication'), namespace='oauth2')),
|
||||
|
||||
path('captcha/', include('captcha.urls')),
|
||||
|
||||
path('oauth2-provider/', include(('authentication.backends.oauth2_provider.urls', 'authentication'), namespace='oauth2-provider'))
|
||||
]
|
||||
|
||||
@@ -11,6 +11,7 @@ from rest_framework.request import Request
|
||||
from authentication import errors
|
||||
from authentication.mixins import AuthMixin
|
||||
from authentication.notifications import OAuthBindMessage
|
||||
from authentication.decorators import post_save_next_to_session_if_guard_redirect
|
||||
from common.utils import get_logger
|
||||
from common.utils.common import get_request_ip
|
||||
from common.utils.django import reverse, get_object_or_none
|
||||
@@ -72,6 +73,7 @@ class BaseLoginCallbackView(AuthMixin, FlashMessageMixin, IMClientMixin, View):
|
||||
|
||||
return user, None
|
||||
|
||||
@post_save_next_to_session_if_guard_redirect
|
||||
def get(self, request: Request):
|
||||
code = request.GET.get('code')
|
||||
redirect_url = request.GET.get('redirect_url')
|
||||
@@ -110,8 +112,6 @@ class BaseLoginCallbackView(AuthMixin, FlashMessageMixin, IMClientMixin, View):
|
||||
response = self.get_failed_response(login_url, title=msg, msg=msg)
|
||||
return response
|
||||
|
||||
if redirect_url and 'next=client' in redirect_url:
|
||||
self.request.META['QUERY_STRING'] += '&next=client'
|
||||
return self.redirect_to_guard_view()
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from django.views import View
|
||||
from rest_framework.exceptions import APIException
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
|
||||
from authentication.decorators import post_save_next_to_session_if_guard_redirect, pre_save_next_to_session
|
||||
from authentication import errors
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.mixins import AuthMixin
|
||||
@@ -24,7 +25,7 @@ from common.views.mixins import PermissionsMixin, UserConfirmRequiredExceptionMi
|
||||
from users.models import User
|
||||
from users.views import UserVerifyPasswordView
|
||||
from .base import BaseLoginCallbackView
|
||||
from .mixins import METAMixin, FlashMessageMixin
|
||||
from .mixins import FlashMessageMixin
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
@@ -171,20 +172,18 @@ class DingTalkEnableStartView(UserVerifyPasswordView):
|
||||
return success_url
|
||||
|
||||
|
||||
class DingTalkQRLoginView(DingTalkQRMixin, METAMixin, View):
|
||||
class DingTalkQRLoginView(DingTalkQRMixin, View):
|
||||
permission_classes = (AllowAny,)
|
||||
|
||||
@pre_save_next_to_session()
|
||||
def get(self, request: HttpRequest):
|
||||
redirect_url = request.GET.get('redirect_url') or reverse('index')
|
||||
query_string = request.GET.urlencode()
|
||||
redirect_url = f'{redirect_url}?{query_string}'
|
||||
next_url = self.get_next_url_from_meta() or reverse('index')
|
||||
next_url = safe_next_url(next_url, request=request)
|
||||
|
||||
redirect_uri = reverse('authentication:dingtalk-qr-login-callback', external=True)
|
||||
redirect_uri += '?' + urlencode({
|
||||
'redirect_url': redirect_url,
|
||||
'next': next_url,
|
||||
})
|
||||
|
||||
url = self.get_qr_url(redirect_uri)
|
||||
@@ -210,6 +209,7 @@ class DingTalkQRLoginCallbackView(DingTalkQRMixin, BaseLoginCallbackView):
|
||||
class DingTalkOAuthLoginView(DingTalkOAuthMixin, View):
|
||||
permission_classes = (AllowAny,)
|
||||
|
||||
@pre_save_next_to_session()
|
||||
def get(self, request: HttpRequest):
|
||||
redirect_url = request.GET.get('redirect_url')
|
||||
|
||||
@@ -223,6 +223,7 @@ class DingTalkOAuthLoginView(DingTalkOAuthMixin, View):
|
||||
class DingTalkOAuthLoginCallbackView(AuthMixin, DingTalkOAuthMixin, View):
|
||||
permission_classes = (AllowAny,)
|
||||
|
||||
@post_save_next_to_session_if_guard_redirect
|
||||
def get(self, request: HttpRequest):
|
||||
code = request.GET.get('code')
|
||||
redirect_url = request.GET.get('redirect_url')
|
||||
|
||||
@@ -8,6 +8,7 @@ from django.views import View
|
||||
from rest_framework.exceptions import APIException
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
|
||||
from authentication.decorators import pre_save_next_to_session
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.permissions import UserConfirmation
|
||||
from common.sdk.im.feishu import URL
|
||||
@@ -108,9 +109,12 @@ class FeiShuQRBindCallbackView(FeiShuQRMixin, BaseBindCallbackView):
|
||||
class FeiShuQRLoginView(FeiShuQRMixin, View):
|
||||
permission_classes = (AllowAny,)
|
||||
|
||||
@pre_save_next_to_session()
|
||||
def get(self, request: HttpRequest):
|
||||
redirect_url = request.GET.get('redirect_url') or reverse('index')
|
||||
query_string = request.GET.urlencode()
|
||||
query_string = request.GET.copy()
|
||||
query_string.pop('next', None)
|
||||
query_string = query_string.urlencode()
|
||||
redirect_url = f'{redirect_url}?{query_string}'
|
||||
redirect_uri = reverse(f'authentication:{self.category}-qr-login-callback', external=True)
|
||||
redirect_uri += '?' + urlencode({
|
||||
|
||||
@@ -29,7 +29,7 @@ from users.utils import (
|
||||
redirect_user_first_login_or_index
|
||||
)
|
||||
from .. import mixins, errors
|
||||
from ..const import RSA_PRIVATE_KEY, RSA_PUBLIC_KEY
|
||||
from ..const import RSA_PRIVATE_KEY, RSA_PUBLIC_KEY, USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD
|
||||
from ..forms import get_user_login_form_cls
|
||||
from ..utils import get_auth_methods
|
||||
|
||||
@@ -260,7 +260,7 @@ class UserLoginView(mixins.AuthMixin, UserLoginContextMixin, FormView):
|
||||
|
||||
|
||||
class UserLoginGuardView(mixins.AuthMixin, RedirectView):
|
||||
redirect_field_name = 'next'
|
||||
redirect_field_name = USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD
|
||||
login_url = reverse_lazy('authentication:login')
|
||||
login_mfa_url = reverse_lazy('authentication:login-mfa')
|
||||
login_confirm_url = reverse_lazy('authentication:login-wait-confirm')
|
||||
|
||||
@@ -67,6 +67,7 @@ class UserLoginMFAView(mixins.AuthMixin, FormView):
|
||||
def get_context_data(self, **kwargs):
|
||||
user = self.get_user_from_session()
|
||||
mfa_context = self.get_user_mfa_context(user)
|
||||
print(mfa_context)
|
||||
kwargs.update(mfa_context)
|
||||
return kwargs
|
||||
|
||||
|
||||
@@ -4,17 +4,6 @@ from django.utils.translation import gettext_lazy as _
|
||||
from common.utils import FlashMessageUtil
|
||||
|
||||
|
||||
class METAMixin:
|
||||
def get_next_url_from_meta(self):
|
||||
request_meta = self.request.META or {}
|
||||
next_url = None
|
||||
referer = request_meta.get('HTTP_REFERER', '')
|
||||
next_url_item = referer.rsplit('next=', 1)
|
||||
if len(next_url_item) > 1:
|
||||
next_url = next_url_item[-1]
|
||||
return next_url
|
||||
|
||||
|
||||
class FlashMessageMixin:
|
||||
@staticmethod
|
||||
def get_response(redirect_url='', title='', msg='', m_type='message', interval=5):
|
||||
|
||||
@@ -8,6 +8,7 @@ from rest_framework.exceptions import APIException
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.request import Request
|
||||
|
||||
from authentication.decorators import pre_save_next_to_session
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.permissions import UserConfirmation
|
||||
from common.sdk.im.slack import URL, SLACK_REDIRECT_URI_SESSION_KEY
|
||||
@@ -96,9 +97,12 @@ class SlackEnableStartView(UserVerifyPasswordView):
|
||||
class SlackQRLoginView(SlackMixin, View):
|
||||
permission_classes = (AllowAny,)
|
||||
|
||||
@pre_save_next_to_session()
|
||||
def get(self, request: Request):
|
||||
redirect_url = request.GET.get('redirect_url') or reverse('index')
|
||||
query_string = request.GET.urlencode()
|
||||
query_string = request.GET.copy()
|
||||
query_string.pop('next', None)
|
||||
query_string = query_string.urlencode()
|
||||
redirect_url = f'{redirect_url}?{query_string}'
|
||||
redirect_uri = reverse('authentication:slack-qr-login-callback', external=True)
|
||||
redirect_uri += '?' + urlencode({
|
||||
|
||||
@@ -12,6 +12,7 @@ from authentication import errors
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.mixins import AuthMixin
|
||||
from authentication.permissions import UserConfirmation
|
||||
from authentication.decorators import post_save_next_to_session_if_guard_redirect, pre_save_next_to_session
|
||||
from common.sdk.im.wecom import URL
|
||||
from common.sdk.im.wecom import WeCom, wecom_tool
|
||||
from common.utils import get_logger
|
||||
@@ -20,7 +21,7 @@ from common.views.mixins import UserConfirmRequiredExceptionMixin, PermissionsMi
|
||||
from users.models import User
|
||||
from users.views import UserVerifyPasswordView
|
||||
from .base import BaseLoginCallbackView, BaseBindCallbackView
|
||||
from .mixins import METAMixin, FlashMessageMixin
|
||||
from .mixins import FlashMessageMixin
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
@@ -115,19 +116,14 @@ class WeComEnableStartView(UserVerifyPasswordView):
|
||||
return success_url
|
||||
|
||||
|
||||
class WeComQRLoginView(WeComQRMixin, METAMixin, View):
|
||||
class WeComQRLoginView(WeComQRMixin, View):
|
||||
permission_classes = (AllowAny,)
|
||||
|
||||
@pre_save_next_to_session()
|
||||
def get(self, request: HttpRequest):
|
||||
redirect_url = request.GET.get('redirect_url') or reverse('index')
|
||||
next_url = self.get_next_url_from_meta() or reverse('index')
|
||||
next_url = safe_next_url(next_url, request=request)
|
||||
redirect_uri = reverse('authentication:wecom-qr-login-callback', external=True)
|
||||
redirect_uri += '?' + urlencode({
|
||||
'redirect_url': redirect_url,
|
||||
'next': next_url,
|
||||
})
|
||||
|
||||
redirect_uri += '?' + urlencode({'redirect_url': redirect_url})
|
||||
url = self.get_qr_url(redirect_uri)
|
||||
return HttpResponseRedirect(url)
|
||||
|
||||
@@ -148,12 +144,11 @@ class WeComQRLoginCallbackView(WeComQRMixin, BaseLoginCallbackView):
|
||||
class WeComOAuthLoginView(WeComOAuthMixin, View):
|
||||
permission_classes = (AllowAny,)
|
||||
|
||||
@pre_save_next_to_session()
|
||||
def get(self, request: HttpRequest):
|
||||
redirect_url = request.GET.get('redirect_url')
|
||||
|
||||
redirect_uri = reverse('authentication:wecom-oauth-login-callback', external=True)
|
||||
redirect_uri += '?' + urlencode({'redirect_url': redirect_url})
|
||||
|
||||
url = self.get_oauth_url(redirect_uri)
|
||||
return HttpResponseRedirect(url)
|
||||
|
||||
@@ -161,6 +156,7 @@ class WeComOAuthLoginView(WeComOAuthMixin, View):
|
||||
class WeComOAuthLoginCallbackView(AuthMixin, WeComOAuthMixin, View):
|
||||
permission_classes = (AllowAny,)
|
||||
|
||||
@post_save_next_to_session_if_guard_redirect
|
||||
def get(self, request: HttpRequest):
|
||||
code = request.GET.get('code')
|
||||
redirect_url = request.GET.get('redirect_url')
|
||||
|
||||
@@ -183,6 +183,7 @@ class BaseFileRenderer(LogMixin, BaseRenderer):
|
||||
for item in data:
|
||||
row = []
|
||||
for field in render_fields:
|
||||
field._row = item
|
||||
value = item.get(field.field_name)
|
||||
value = self.render_value(field, value)
|
||||
row.append(value)
|
||||
|
||||
@@ -15,6 +15,7 @@ class Device:
|
||||
self.__load_driver(driver_path)
|
||||
# open device
|
||||
self.__open_device()
|
||||
self.__reset_key_store()
|
||||
|
||||
def close(self):
|
||||
if self.__device is None:
|
||||
@@ -68,3 +69,12 @@ class Device:
|
||||
if ret != 0:
|
||||
raise PiicoError("open piico device failed", ret)
|
||||
self.__device = device
|
||||
|
||||
def __reset_key_store(self):
|
||||
if self._driver is None:
|
||||
raise PiicoError("no driver loaded", 0)
|
||||
if self.__device is None:
|
||||
raise PiicoError("device not open", 0)
|
||||
ret = self._driver.SPII_ResetModule(self.__device)
|
||||
if ret != 0:
|
||||
raise PiicoError("reset device failed", ret)
|
||||
|
||||
@@ -192,6 +192,7 @@ class WeCom(RequestMixin):
|
||||
class WeComTool(object):
|
||||
WECOM_STATE_SESSION_KEY = '_wecom_state'
|
||||
WECOM_STATE_VALUE = 'wecom'
|
||||
WECOM_STATE_NEXT_URL_KEY = 'wecom_oauth_next_url'
|
||||
|
||||
@lazyproperty
|
||||
def qr_cb_url(self):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -280,7 +280,8 @@
|
||||
"CACertificate": "Ca certificate",
|
||||
"CAS": "CAS",
|
||||
"CMPP2": "Cmpp v2.0",
|
||||
"CTYunPrivate": "eCloud Private Cloud",
|
||||
"CTYun": "State Cloud",
|
||||
"CTYunPrivate": "State Cloud(Private)",
|
||||
"CalculationResults": "Error in cron expression",
|
||||
"CallRecords": "Call Records",
|
||||
"CanDragSelect": "Select by dragging; Empty means all selected",
|
||||
@@ -1634,5 +1635,8 @@
|
||||
"selectedAssets": "Selected assets",
|
||||
"setVariable": "Set variable",
|
||||
"userId": "User ID",
|
||||
"userName": "User name"
|
||||
}
|
||||
"userName": "User name",
|
||||
"AccessToken": "Access tokens",
|
||||
"AccessTokenTip": "Access Token is a temporary credential generated through the OAuth2 (Authorization Code Grant) flow using the JumpServer client, which is used to access protected resources.",
|
||||
"Revoke": "Revoke"
|
||||
}
|
||||
|
||||
@@ -279,6 +279,7 @@
|
||||
"CACertificate": "CA 证书",
|
||||
"CAS": "CAS",
|
||||
"CMPP2": "CMPP v2.0",
|
||||
"CTYun": "天翼云",
|
||||
"CTYunPrivate": "天翼私有云",
|
||||
"CalculationResults": "cron 表达式错误",
|
||||
"CallRecords": "调用记录",
|
||||
@@ -1644,5 +1645,8 @@
|
||||
"userId": "用户ID",
|
||||
"userName": "用户名",
|
||||
"Risk": "风险",
|
||||
"selectFiles": "已选择选择{number}文件"
|
||||
}
|
||||
"selectFiles": "已选择选择{number}文件",
|
||||
"AccessToken": "访问令牌",
|
||||
"AccessTokenTip": "访问令牌是通过 JumpServer 客户端使用 OAuth2(授权码授权)流程生成的临时凭证,用于访问受保护的资源。",
|
||||
"Revoke": "撤销"
|
||||
}
|
||||
|
||||
@@ -381,7 +381,6 @@ class Config(dict):
|
||||
'CAS_USERNAME_ATTRIBUTE': 'cas:user',
|
||||
'CAS_APPLY_ATTRIBUTES_TO_USER': False,
|
||||
'CAS_RENAME_ATTRIBUTES': {'cas:user': 'username'},
|
||||
'CAS_CREATE_USER': True,
|
||||
'CAS_ORG_IDS': [DEFAULT_ID],
|
||||
|
||||
'AUTH_SSO': False,
|
||||
@@ -569,7 +568,7 @@ class Config(dict):
|
||||
'SAFE_MODE': False,
|
||||
'SECURITY_MFA_AUTH': 0, # 0 不开启 1 全局开启 2 管理员开启
|
||||
'SECURITY_MFA_AUTH_ENABLED_FOR_THIRD_PARTY': True,
|
||||
'SECURITY_MFA_ENABLED_BACKENDS': [],
|
||||
'SECURITY_MFA_BY_EMAIL': False,
|
||||
'SECURITY_COMMAND_EXECUTION': False,
|
||||
'SECURITY_COMMAND_BLACKLIST': [
|
||||
'reboot', 'shutdown', 'poweroff', 'halt', 'dd', 'half', 'top'
|
||||
@@ -692,9 +691,9 @@ class Config(dict):
|
||||
'FTP_FILE_MAX_STORE': 0,
|
||||
|
||||
# API 分页
|
||||
'MAX_LIMIT_PER_PAGE': 10000, # 给导出用
|
||||
'MAX_LIMIT_PER_PAGE': 10000, # 给导出用
|
||||
'MAX_PAGE_SIZE': 1000,
|
||||
'DEFAULT_PAGE_SIZE': 200, # 给没有请求分页的用
|
||||
'DEFAULT_PAGE_SIZE': 200, # 给没有请求分页的用
|
||||
|
||||
'LIMIT_SUPER_PRIV': False,
|
||||
|
||||
@@ -702,15 +701,7 @@ class Config(dict):
|
||||
'CHAT_AI_ENABLED': False,
|
||||
'CHAT_AI_METHOD': 'api',
|
||||
'CHAT_AI_EMBED_URL': '',
|
||||
'CHAT_AI_TYPE': 'gpt',
|
||||
'GPT_BASE_URL': '',
|
||||
'GPT_API_KEY': '',
|
||||
'GPT_PROXY': '',
|
||||
'GPT_MODEL': 'gpt-4o-mini',
|
||||
'DEEPSEEK_BASE_URL': '',
|
||||
'DEEPSEEK_API_KEY': '',
|
||||
'DEEPSEEK_PROXY': '',
|
||||
'DEEPSEEK_MODEL': 'deepseek-chat',
|
||||
'CHAT_AI_PROVIDERS': [],
|
||||
'VIRTUAL_APP_ENABLED': False,
|
||||
|
||||
'FILE_UPLOAD_SIZE_LIMIT_MB': 200,
|
||||
@@ -735,6 +726,10 @@ class Config(dict):
|
||||
|
||||
# MCP
|
||||
'MCP_ENABLED': False,
|
||||
|
||||
# oauth2_provider settings
|
||||
'OAUTH2_PROVIDER_ACCESS_TOKEN_EXPIRE_SECONDS': 60 * 60,
|
||||
'OAUTH2_PROVIDER_REFRESH_TOKEN_EXPIRE_SECONDS': 60 * 60 * 24 * 7,
|
||||
}
|
||||
|
||||
old_config_map = {
|
||||
|
||||
@@ -151,8 +151,13 @@ class SafeRedirectMiddleware:
|
||||
|
||||
if not (300 <= response.status_code < 400):
|
||||
return response
|
||||
if request.resolver_match and request.resolver_match.namespace.startswith('authentication'):
|
||||
# 认证相关的路由跳过验证(core/auth/xxxx
|
||||
if (
|
||||
request.resolver_match and
|
||||
request.resolver_match.namespace.startswith('authentication') and
|
||||
not request.resolver_match.namespace.startswith('authentication:oauth2-provider')
|
||||
):
|
||||
# 认证相关的路由跳过验证 /core/auth/...,
|
||||
# 但 oauth2-provider 除外, 因为它会重定向到第三方客户端, 希望给出更友好的提示
|
||||
return response
|
||||
location = response.get('Location')
|
||||
if not location:
|
||||
|
||||
@@ -159,7 +159,8 @@ CAS_CHECK_NEXT = lambda _next_page: True
|
||||
CAS_USERNAME_ATTRIBUTE = CONFIG.CAS_USERNAME_ATTRIBUTE
|
||||
CAS_APPLY_ATTRIBUTES_TO_USER = CONFIG.CAS_APPLY_ATTRIBUTES_TO_USER
|
||||
CAS_RENAME_ATTRIBUTES = CONFIG.CAS_RENAME_ATTRIBUTES
|
||||
CAS_CREATE_USER = CONFIG.CAS_CREATE_USER
|
||||
CAS_CREATE_USER = True
|
||||
CAS_STORE_NEXT = True
|
||||
|
||||
# SSO auth
|
||||
AUTH_SSO = CONFIG.AUTH_SSO
|
||||
|
||||
@@ -130,6 +130,7 @@ INSTALLED_APPS = [
|
||||
'settings.apps.SettingsConfig',
|
||||
'terminal.apps.TerminalConfig',
|
||||
'audits.apps.AuditsConfig',
|
||||
'oauth2_provider',
|
||||
'authentication.apps.AuthenticationConfig', # authentication
|
||||
'tickets.apps.TicketsConfig',
|
||||
'acls.apps.AclsConfig',
|
||||
|
||||
@@ -241,15 +241,7 @@ ASSET_SIZE = 'small'
|
||||
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
|
||||
CHAT_AI_METHOD = CONFIG.CHAT_AI_METHOD
|
||||
CHAT_AI_EMBED_URL = CONFIG.CHAT_AI_EMBED_URL
|
||||
CHAT_AI_TYPE = CONFIG.CHAT_AI_TYPE
|
||||
GPT_BASE_URL = CONFIG.GPT_BASE_URL
|
||||
GPT_API_KEY = CONFIG.GPT_API_KEY
|
||||
GPT_PROXY = CONFIG.GPT_PROXY
|
||||
GPT_MODEL = CONFIG.GPT_MODEL
|
||||
DEEPSEEK_BASE_URL = CONFIG.DEEPSEEK_BASE_URL
|
||||
DEEPSEEK_API_KEY = CONFIG.DEEPSEEK_API_KEY
|
||||
DEEPSEEK_PROXY = CONFIG.DEEPSEEK_PROXY
|
||||
DEEPSEEK_MODEL = CONFIG.DEEPSEEK_MODEL
|
||||
CHAT_AI_DEFAULT_PROVIDER = CONFIG.CHAT_AI_DEFAULT_PROVIDER
|
||||
|
||||
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
|
||||
|
||||
@@ -268,4 +260,6 @@ LOKI_BASE_URL = CONFIG.LOKI_BASE_URL
|
||||
TOOL_USER_ENABLED = CONFIG.TOOL_USER_ENABLED
|
||||
|
||||
SUGGESTION_LIMIT = CONFIG.SUGGESTION_LIMIT
|
||||
MCP_ENABLED = CONFIG.MCP_ENABLED
|
||||
MCP_ENABLED = CONFIG.MCP_ENABLED
|
||||
CHAT_AI_PROVIDERS = CONFIG.CHAT_AI_PROVIDERS
|
||||
|
||||
|
||||
@@ -30,10 +30,11 @@ REST_FRAMEWORK = {
|
||||
),
|
||||
'DEFAULT_AUTHENTICATION_CLASSES': (
|
||||
# 'rest_framework.authentication.BasicAuthentication',
|
||||
'authentication.backends.drf.AccessTokenAuthentication',
|
||||
'authentication.backends.drf.PrivateTokenAuthentication',
|
||||
'authentication.backends.drf.ServiceAuthentication',
|
||||
'authentication.backends.drf.SignatureAuthentication',
|
||||
'authentication.backends.drf.ServiceAuthentication',
|
||||
'authentication.backends.drf.PrivateTokenAuthentication',
|
||||
'authentication.backends.drf.AccessTokenAuthentication',
|
||||
"oauth2_provider.contrib.rest_framework.OAuth2Authentication",
|
||||
'authentication.backends.drf.SessionAuthentication',
|
||||
),
|
||||
'DEFAULT_FILTER_BACKENDS': (
|
||||
@@ -222,3 +223,17 @@ PIICO_DRIVER_PATH = CONFIG.PIICO_DRIVER_PATH
|
||||
LEAK_PASSWORD_DB_PATH = CONFIG.LEAK_PASSWORD_DB_PATH
|
||||
|
||||
JUMPSERVER_UPTIME = int(time.time())
|
||||
|
||||
# OAuth2 Provider settings
|
||||
OAUTH2_PROVIDER = {
|
||||
'ALLOWED_REDIRECT_URI_SCHEMES': ['https', 'jms'],
|
||||
'PKCE_REQUIRED': True,
|
||||
'ACCESS_TOKEN_EXPIRE_SECONDS': CONFIG.OAUTH2_PROVIDER_ACCESS_TOKEN_EXPIRE_SECONDS,
|
||||
'REFRESH_TOKEN_EXPIRE_SECONDS': CONFIG.OAUTH2_PROVIDER_REFRESH_TOKEN_EXPIRE_SECONDS,
|
||||
}
|
||||
OAUTH2_PROVIDER_CLIENT_REDIRECT_URI = 'jms://auth/callback'
|
||||
OAUTH2_PROVIDER_JUMPSERVER_CLIENT_NAME = 'JumpServer Client'
|
||||
|
||||
if CONFIG.DEBUG_DEV:
|
||||
OAUTH2_PROVIDER['ALLOWED_REDIRECT_URI_SCHEMES'].append('http')
|
||||
OAUTH2_PROVIDER_CLIENT_REDIRECT_URI += ' http://127.0.0.1:14876/auth/callback'
|
||||
|
||||
@@ -148,6 +148,6 @@ class RedirectConfirm(TemplateView):
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
return False
|
||||
if parsed.scheme not in ['http', 'https']:
|
||||
if parsed.scheme not in ['http', 'https', 'jms']:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -133,6 +133,11 @@ exclude_permissions = (
|
||||
('terminal', 'session', 'delete,change', 'command'),
|
||||
('applications', '*', '*', '*'),
|
||||
('settings', 'chatprompt', 'add,delete,change', 'chatprompt'),
|
||||
('oauth2_provider', 'grant', '*', '*'),
|
||||
('oauth2_provider', 'refreshtoken', '*', '*'),
|
||||
('oauth2_provider', 'idtoken', '*', '*'),
|
||||
('oauth2_provider', 'application', '*', '*'),
|
||||
('oauth2_provider', 'accesstoken', 'add,change', 'accesstoken')
|
||||
)
|
||||
|
||||
only_system_permissions = (
|
||||
@@ -160,6 +165,7 @@ only_system_permissions = (
|
||||
('authentication', 'temptoken', '*', '*'),
|
||||
('authentication', 'passkey', '*', '*'),
|
||||
('authentication', 'ssotoken', '*', '*'),
|
||||
('oauth2_provider', 'accesstoken', '*', '*'),
|
||||
('tickets', '*', '*', '*'),
|
||||
('orgs', 'organization', 'view', 'rootorg'),
|
||||
('terminal', 'applet', '*', '*'),
|
||||
|
||||
@@ -129,6 +129,7 @@ special_pid_mapper = {
|
||||
"rbac.view_systemtools": "view_workbench",
|
||||
'tickets.view_ticket': 'tickets',
|
||||
"audits.joblog": "job_audit",
|
||||
'oauth2_provider.accesstoken': 'authentication',
|
||||
}
|
||||
|
||||
special_setting_pid_mapper = {
|
||||
@@ -184,6 +185,11 @@ verbose_name_mapper = {
|
||||
'tickets.view_ticket': _("Ticket"),
|
||||
'settings.setting': _("Common setting"),
|
||||
'rbac.view_permission': _('View permission tree'),
|
||||
'authentication.passkey': _("Passkey"),
|
||||
'oauth2_provider.accesstoken': _("Access token"),
|
||||
'oauth2_provider.view_accesstoken': _("View access token"),
|
||||
'oauth2_provider.delete_accesstoken': _("Revoke access token"),
|
||||
|
||||
}
|
||||
|
||||
xpack_nodes = [
|
||||
|
||||
@@ -1,98 +1,10 @@
|
||||
import httpx
|
||||
import openai
|
||||
from django.conf import settings
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import status
|
||||
from rest_framework.generics import GenericAPIView
|
||||
from rest_framework.response import Response
|
||||
|
||||
from common.api import JMSModelViewSet
|
||||
from common.permissions import IsValidUser, OnlySuperUser
|
||||
from .. import serializers
|
||||
from ..const import ChatAITypeChoices
|
||||
from ..models import ChatPrompt
|
||||
from ..prompt import DefaultChatPrompt
|
||||
|
||||
|
||||
class ChatAITestingAPI(GenericAPIView):
|
||||
serializer_class = serializers.ChatAISettingSerializer
|
||||
rbac_perms = {
|
||||
'POST': 'settings.change_chatai'
|
||||
}
|
||||
|
||||
def get_config(self, request):
|
||||
serializer = self.serializer_class(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
data = self.serializer_class().data
|
||||
data.update(serializer.validated_data)
|
||||
for k, v in data.items():
|
||||
if v:
|
||||
continue
|
||||
# 页面没有传递值, 从 settings 中获取
|
||||
data[k] = getattr(settings, k, None)
|
||||
return data
|
||||
|
||||
def post(self, request):
|
||||
config = self.get_config(request)
|
||||
chat_ai_enabled = config['CHAT_AI_ENABLED']
|
||||
if not chat_ai_enabled:
|
||||
return Response(
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
data={'msg': _('Chat AI is not enabled')}
|
||||
)
|
||||
|
||||
tp = config['CHAT_AI_TYPE']
|
||||
if tp == ChatAITypeChoices.gpt:
|
||||
url = config['GPT_BASE_URL']
|
||||
api_key = config['GPT_API_KEY']
|
||||
proxy = config['GPT_PROXY']
|
||||
model = config['GPT_MODEL']
|
||||
else:
|
||||
url = config['DEEPSEEK_BASE_URL']
|
||||
api_key = config['DEEPSEEK_API_KEY']
|
||||
proxy = config['DEEPSEEK_PROXY']
|
||||
model = config['DEEPSEEK_MODEL']
|
||||
|
||||
kwargs = {
|
||||
'base_url': url or None,
|
||||
'api_key': api_key,
|
||||
}
|
||||
try:
|
||||
if proxy:
|
||||
kwargs['http_client'] = httpx.Client(
|
||||
proxies=proxy,
|
||||
transport=httpx.HTTPTransport(local_address='0.0.0.0')
|
||||
)
|
||||
client = openai.OpenAI(**kwargs)
|
||||
|
||||
ok = False
|
||||
error = ''
|
||||
|
||||
client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Say this is a test",
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
)
|
||||
ok = True
|
||||
except openai.APIConnectionError as e:
|
||||
error = str(e.__cause__) # an underlying Exception, likely raised within httpx.
|
||||
except openai.APIStatusError as e:
|
||||
error = str(e.message)
|
||||
except Exception as e:
|
||||
ok, error = False, str(e)
|
||||
|
||||
if ok:
|
||||
_status, msg = status.HTTP_200_OK, _('Test success')
|
||||
else:
|
||||
_status, msg = status.HTTP_400_BAD_REQUEST, error
|
||||
|
||||
return Response(status=_status, data={'msg': msg})
|
||||
|
||||
|
||||
class ChatPromptViewSet(JMSModelViewSet):
|
||||
serializer_classes = {
|
||||
'default': serializers.ChatPromptSerializer,
|
||||
|
||||
@@ -154,7 +154,10 @@ class SettingsApi(generics.RetrieveUpdateAPIView):
|
||||
def parse_serializer_data(self, serializer):
|
||||
data = []
|
||||
fields = self.get_fields()
|
||||
encrypted_items = [name for name, field in fields.items() if field.write_only]
|
||||
encrypted_items = [
|
||||
name for name, field in fields.items()
|
||||
if field.write_only or getattr(field, 'encrypted', False)
|
||||
]
|
||||
category = self.request.query_params.get('category', '')
|
||||
for name, value in serializer.validated_data.items():
|
||||
encrypted = name in encrypted_items
|
||||
|
||||
@@ -14,18 +14,5 @@ class ChatAIMethodChoices(TextChoices):
|
||||
|
||||
|
||||
class ChatAITypeChoices(TextChoices):
|
||||
gpt = 'gpt', 'GPT'
|
||||
deep_seek = 'deep-seek', 'DeepSeek'
|
||||
|
||||
|
||||
class GPTModelChoices(TextChoices):
|
||||
gpt_4o_mini = 'gpt-4o-mini', 'gpt-4o-mini'
|
||||
gpt_4o = 'gpt-4o', 'gpt-4o'
|
||||
o3_mini = 'o3-mini', 'o3-mini'
|
||||
o1_mini = 'o1-mini', 'o1-mini'
|
||||
o1 = 'o1', 'o1'
|
||||
|
||||
|
||||
class DeepSeekModelChoices(TextChoices):
|
||||
deepseek_chat = 'deepseek-chat', 'DeepSeek-V3'
|
||||
deepseek_reasoner = 'deepseek-reasoner', 'DeepSeek-R1'
|
||||
openai = 'openai', 'Openai'
|
||||
ollama = 'ollama', 'Ollama'
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.files.base import ContentFile
|
||||
@@ -14,7 +15,6 @@ from rest_framework.utils.encoders import JSONEncoder
|
||||
from common.db.models import JMSBaseModel
|
||||
from common.db.utils import Encryptor
|
||||
from common.utils import get_logger
|
||||
from .const import ChatAITypeChoices
|
||||
from .signals import setting_changed
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -196,20 +196,25 @@ class ChatPrompt(JMSBaseModel):
|
||||
return self.name
|
||||
|
||||
|
||||
def get_chatai_data():
|
||||
data = {
|
||||
'url': settings.GPT_BASE_URL,
|
||||
'api_key': settings.GPT_API_KEY,
|
||||
'proxy': settings.GPT_PROXY,
|
||||
'model': settings.GPT_MODEL,
|
||||
}
|
||||
if settings.CHAT_AI_TYPE != ChatAITypeChoices.gpt:
|
||||
data['url'] = settings.DEEPSEEK_BASE_URL
|
||||
data['api_key'] = settings.DEEPSEEK_API_KEY
|
||||
data['proxy'] = settings.DEEPSEEK_PROXY
|
||||
data['model'] = settings.DEEPSEEK_MODEL
|
||||
def get_chatai_data() -> Dict[str, Any]:
|
||||
raw_providers = settings.CHAT_AI_PROVIDERS
|
||||
providers: List[dict] = [p for p in raw_providers if isinstance(p, dict)]
|
||||
|
||||
return data
|
||||
if not providers:
|
||||
return {}
|
||||
|
||||
selected = next(
|
||||
(p for p in providers if p.get('is_assistant')),
|
||||
providers[0],
|
||||
)
|
||||
|
||||
return {
|
||||
'url': selected.get('base_url'),
|
||||
'api_key': selected.get('api_key'),
|
||||
'proxy': selected.get('proxy'),
|
||||
'model': selected.get('model'),
|
||||
'name': selected.get('name'),
|
||||
}
|
||||
|
||||
|
||||
def init_sqlite_db():
|
||||
|
||||
@@ -37,11 +37,4 @@ class CASSettingSerializer(serializers.Serializer):
|
||||
"and the `value` is the JumpServer user attribute name"
|
||||
)
|
||||
)
|
||||
CAS_CREATE_USER = serializers.BooleanField(
|
||||
required=False, label=_('Create user'),
|
||||
help_text=_(
|
||||
'After successful user authentication, if the user does not exist, '
|
||||
'automatically create the user'
|
||||
)
|
||||
)
|
||||
CAS_ORG_IDS = OrgListField()
|
||||
CAS_ORG_IDS = OrgListField()
|
||||
|
||||
@@ -10,11 +10,12 @@ from common.utils import date_expired_default
|
||||
__all__ = [
|
||||
'AnnouncementSettingSerializer', 'OpsSettingSerializer', 'VaultSettingSerializer',
|
||||
'HashicorpKVSerializer', 'AzureKVSerializer', 'TicketSettingSerializer',
|
||||
'ChatAISettingSerializer', 'VirtualAppSerializer', 'AmazonSMSerializer',
|
||||
'ChatAIProviderSerializer', 'ChatAISettingSerializer',
|
||||
'VirtualAppSerializer', 'AmazonSMSerializer',
|
||||
]
|
||||
|
||||
from settings.const import (
|
||||
ChatAITypeChoices, GPTModelChoices, DeepSeekModelChoices, ChatAIMethodChoices
|
||||
ChatAITypeChoices, ChatAIMethodChoices
|
||||
)
|
||||
|
||||
|
||||
@@ -120,6 +121,29 @@ class AmazonSMSerializer(serializers.Serializer):
|
||||
)
|
||||
|
||||
|
||||
class ChatAIProviderListSerializer(serializers.ListSerializer):
|
||||
# 标记整个列表需要加密存储,避免明文保存 API Key
|
||||
encrypted = True
|
||||
|
||||
|
||||
class ChatAIProviderSerializer(serializers.Serializer):
|
||||
type = serializers.ChoiceField(
|
||||
default=ChatAITypeChoices.openai, choices=ChatAITypeChoices.choices,
|
||||
label=_("Types"), required=False,
|
||||
)
|
||||
base_url = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Base URL'),
|
||||
help_text=_('The base URL of the Chat service.')
|
||||
)
|
||||
api_key = EncryptedField(
|
||||
allow_blank=True, required=False, label=_('API Key'),
|
||||
)
|
||||
proxy = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Proxy'),
|
||||
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
|
||||
)
|
||||
|
||||
|
||||
class ChatAISettingSerializer(serializers.Serializer):
|
||||
PREFIX_TITLE = _('Chat AI')
|
||||
|
||||
@@ -130,44 +154,14 @@ class ChatAISettingSerializer(serializers.Serializer):
|
||||
default=ChatAIMethodChoices.api, choices=ChatAIMethodChoices.choices,
|
||||
label=_("Method"), required=False,
|
||||
)
|
||||
CHAT_AI_PROVIDERS = ChatAIProviderListSerializer(
|
||||
child=ChatAIProviderSerializer(),
|
||||
allow_empty=True, required=False, default=list, label=_('Providers')
|
||||
)
|
||||
CHAT_AI_EMBED_URL = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Base URL'),
|
||||
help_text=_('The base URL of the Chat service.')
|
||||
)
|
||||
CHAT_AI_TYPE = serializers.ChoiceField(
|
||||
default=ChatAITypeChoices.gpt, choices=ChatAITypeChoices.choices,
|
||||
label=_("Types"), required=False,
|
||||
)
|
||||
GPT_BASE_URL = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Base URL'),
|
||||
help_text=_('The base URL of the Chat service.')
|
||||
)
|
||||
GPT_API_KEY = EncryptedField(
|
||||
allow_blank=True, required=False, label=_('API Key'),
|
||||
)
|
||||
GPT_PROXY = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Proxy'),
|
||||
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
|
||||
)
|
||||
GPT_MODEL = serializers.ChoiceField(
|
||||
default=GPTModelChoices.gpt_4o_mini, choices=GPTModelChoices.choices,
|
||||
label=_("GPT Model"), required=False,
|
||||
)
|
||||
DEEPSEEK_BASE_URL = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Base URL'),
|
||||
help_text=_('The base URL of the Chat service.')
|
||||
)
|
||||
DEEPSEEK_API_KEY = EncryptedField(
|
||||
allow_blank=True, required=False, label=_('API Key'),
|
||||
)
|
||||
DEEPSEEK_PROXY = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Proxy'),
|
||||
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
|
||||
)
|
||||
DEEPSEEK_MODEL = serializers.ChoiceField(
|
||||
default=DeepSeekModelChoices.deepseek_chat, choices=DeepSeekModelChoices.choices,
|
||||
label=_("DeepSeek Model"), required=False,
|
||||
)
|
||||
|
||||
|
||||
class TicketSettingSerializer(serializers.Serializer):
|
||||
|
||||
@@ -73,8 +73,6 @@ class PrivateSettingSerializer(PublicSettingSerializer):
|
||||
CHAT_AI_ENABLED = serializers.BooleanField()
|
||||
CHAT_AI_METHOD = serializers.CharField()
|
||||
CHAT_AI_EMBED_URL = serializers.CharField()
|
||||
CHAT_AI_TYPE = serializers.CharField()
|
||||
GPT_MODEL = serializers.CharField()
|
||||
FILE_UPLOAD_SIZE_LIMIT_MB = serializers.IntegerField()
|
||||
FTP_FILE_MAX_STORE = serializers.IntegerField()
|
||||
LOKI_LOG_ENABLED = serializers.BooleanField()
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
@@ -121,35 +117,6 @@ class SecurityLoginLimitSerializer(serializers.Serializer):
|
||||
)
|
||||
|
||||
|
||||
class DynamicMFAChoiceField(serializers.MultipleChoiceField):
|
||||
def __init__(self, **kwargs):
|
||||
_choices = self._get_dynamic_choices()
|
||||
super().__init__(choices=_choices, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _get_dynamic_choices():
|
||||
choices = []
|
||||
mfa_dir = os.path.join(settings.APPS_DIR, 'authentication', 'mfa')
|
||||
for filename in os.listdir(mfa_dir):
|
||||
if not filename.endswith('.py') or filename.startswith('__init__'):
|
||||
continue
|
||||
|
||||
module_name = f'authentication.mfa.{filename[:-3]}'
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError:
|
||||
continue
|
||||
|
||||
for attr_name in dir(module):
|
||||
item = getattr(module, attr_name)
|
||||
if not isinstance(item, type) or not attr_name.startswith('MFA'):
|
||||
continue
|
||||
if 'BaseMFA' != item.__base__.__name__:
|
||||
continue
|
||||
choices.append((item.name, item.display_name))
|
||||
return choices
|
||||
|
||||
|
||||
class SecurityAuthSerializer(serializers.Serializer):
|
||||
SECURITY_MFA_AUTH = serializers.ChoiceField(
|
||||
choices=(
|
||||
@@ -163,10 +130,10 @@ class SecurityAuthSerializer(serializers.Serializer):
|
||||
required=False, default=True,
|
||||
label=_('Third-party login MFA'),
|
||||
)
|
||||
SECURITY_MFA_ENABLED_BACKENDS = DynamicMFAChoiceField(
|
||||
default=[], allow_empty=True,
|
||||
label=_('MFA Backends'),
|
||||
help_text=_('MFA methods supported for user login')
|
||||
SECURITY_MFA_BY_EMAIL = serializers.BooleanField(
|
||||
required=False, default=False,
|
||||
label=_('MFA via Email'),
|
||||
help_text=_('Email as a method for multi-factor authentication')
|
||||
)
|
||||
OTP_ISSUER_NAME = serializers.CharField(
|
||||
required=False, max_length=16, label=_('OTP issuer name'),
|
||||
|
||||
@@ -21,7 +21,6 @@ urlpatterns = [
|
||||
path('sms/<str:backend>/testing/', api.SMSTestingAPI.as_view(), name='sms-testing'),
|
||||
path('sms/backend/', api.SMSBackendAPI.as_view(), name='sms-backend'),
|
||||
path('vault/<str:backend>/testing/', api.VaultTestingAPI.as_view(), name='vault-testing'),
|
||||
path('chatai/testing/', api.ChatAITestingAPI.as_view(), name='chatai-testing'),
|
||||
path('vault/sync/', api.VaultSyncDataAPI.as_view(), name='vault-sync'),
|
||||
path('security/block-ip/', api.BlockIPSecurityAPI.as_view(), name='block-ip'),
|
||||
path('security/unlock-ip/', api.UnlockIPSecurityAPI.as_view(), name='unlock-ip'),
|
||||
|
||||
@@ -524,13 +524,16 @@ class LDAPTestUtil(object):
|
||||
# test server uri
|
||||
|
||||
def _check_server_uri(self):
|
||||
if not any([self.config.server_uri.startswith('ldap://') or
|
||||
self.config.server_uri.startswith('ldaps://')]):
|
||||
if not (self.config.server_uri.startswith('ldap://') or
|
||||
self.config.server_uri.startswith('ldaps://')):
|
||||
err = _('ldap:// or ldaps:// protocol is used.')
|
||||
raise LDAPInvalidServerError(err)
|
||||
|
||||
def _test_server_uri(self):
|
||||
self._test_connection_bind()
|
||||
# 这里测试 server uri 是否能连通, 不进行 bind 操作, 不需要传入 bind dn 和密码
|
||||
server = Server(self.config.server_uri, use_ssl=self.config.use_ssl)
|
||||
connection = Connection(server)
|
||||
connection.open()
|
||||
|
||||
def test_server_uri(self):
|
||||
try:
|
||||
|
||||
@@ -6,10 +6,6 @@
|
||||
|
||||
{% block content %}
|
||||
<style>
|
||||
.alert.alert-msg {
|
||||
background: #F5F5F7;
|
||||
}
|
||||
|
||||
.target-url {
|
||||
display: inline-block;
|
||||
max-width: 100%;
|
||||
@@ -18,29 +14,96 @@
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
vertical-align: middle;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
/* 重定向中的样式 */
|
||||
.redirecting-container {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.redirecting-container.show {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.confirm-container {
|
||||
transition: opacity 0.3s ease-out;
|
||||
}
|
||||
|
||||
.confirm-container.hide {
|
||||
display: none;
|
||||
}
|
||||
</style>
|
||||
<div>
|
||||
<p>
|
||||
<div class="alert {% if error %} alert-danger {% else %} alert-info {% endif %}" id="messages">
|
||||
{% trans 'You are about to be redirected to an external website. Please confirm that you trust this link: ' %}
|
||||
<a class="target-url" href="{{ target_url }}">{{ target_url }}</a>
|
||||
</div>
|
||||
</p>
|
||||
<!-- 确认内容 -->
|
||||
<div class="confirm-container" id="confirmContainer">
|
||||
<p>
|
||||
<div class="alert {% if error %} alert-danger {% else %} alert-info {% endif %}" id="messages">
|
||||
{% trans 'You are about to be redirected to an external website.' %}
|
||||
<br/>
|
||||
<br/>
|
||||
{% trans 'Please confirm that you trust this link: ' %}
|
||||
<br/>
|
||||
<br/>
|
||||
<a class="target-url" href="javascript:void(0)" onclick="handleRedirect(event)">{{ target_url }}</a>
|
||||
</div>
|
||||
</p>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-sm-3">
|
||||
<a href="/" class="btn btn-default block full-width m-b">
|
||||
{% trans 'Cancel' %}
|
||||
</a>
|
||||
</div>
|
||||
<div class="col-sm-3">
|
||||
<a href="{{ target_url }}" class="btn btn-primary block full-width m-b">
|
||||
{% trans 'Confirm' %}
|
||||
</a>
|
||||
<div class="row">
|
||||
<div class="col-sm-3">
|
||||
<a href="/" class="btn btn-default block full-width m-b">
|
||||
{% trans 'Back' %}
|
||||
</a>
|
||||
</div>
|
||||
<div class="col-sm-3">
|
||||
<a href="javascript:void(0)" onclick="handleRedirect(event)" class="btn btn-primary block full-width m-b">
|
||||
{% trans 'Confirm' %}
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 重定向中内容 -->
|
||||
<div class="redirecting-container" id="redirectingContainer">
|
||||
<p>
|
||||
<div class="alert alert-info" id="messages">
|
||||
{% trans 'Redirecting you to the Desktop App ( JumpServer Client )' %}
|
||||
<br/>
|
||||
<br/>
|
||||
{% trans 'You can safely close this window and return to the application.' %}
|
||||
</div>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
<script>
|
||||
const targetUrl = '{{ target_url }}';
|
||||
|
||||
// 判断是否是 jms:// 协议
|
||||
function isJmsProtocol(url) {
|
||||
return url.toLowerCase().startsWith('jms://');
|
||||
}
|
||||
|
||||
function handleRedirect(event) {
|
||||
// 如果有 event,阻止默认行为
|
||||
if (event) {
|
||||
event.preventDefault();
|
||||
}
|
||||
|
||||
if (isJmsProtocol(targetUrl)) {
|
||||
// 隐藏确认内容
|
||||
document.getElementById('confirmContainer').classList.add('hide');
|
||||
// 显示重定向中
|
||||
document.getElementById('redirectingContainer').classList.add('show');
|
||||
}
|
||||
|
||||
// 延迟后执行跳转(让用户看到加载动画)
|
||||
setTimeout(() => {
|
||||
window.location.href = targetUrl;
|
||||
}, 100);
|
||||
}
|
||||
</script>
|
||||
{% endblock %}
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from .applet import *
|
||||
from .chat import *
|
||||
from .component import *
|
||||
from .session import *
|
||||
from .virtualapp import *
|
||||
|
||||
1
apps/terminal/api/chat/__init__.py
Normal file
1
apps/terminal/api/chat/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .chat import *
|
||||
15
apps/terminal/api/chat/chat.py
Normal file
15
apps/terminal/api/chat/chat.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from common.api import JMSBulkModelViewSet
|
||||
from terminal import serializers
|
||||
from terminal.filters import ChatFilter
|
||||
from terminal.models import Chat
|
||||
|
||||
__all__ = ['ChatViewSet']
|
||||
|
||||
|
||||
class ChatViewSet(JMSBulkModelViewSet):
|
||||
queryset = Chat.objects.all()
|
||||
serializer_class = serializers.ChatSerializer
|
||||
filterset_class = ChatFilter
|
||||
search_fields = ['title']
|
||||
ordering_fields = ['date_updated']
|
||||
ordering = ['-date_updated']
|
||||
@@ -2,7 +2,7 @@ from django.db.models import QuerySet
|
||||
from django_filters import rest_framework as filters
|
||||
|
||||
from orgs.utils import filter_org_queryset
|
||||
from terminal.models import Command, CommandStorage, Session
|
||||
from terminal.models import Command, CommandStorage, Session, Chat
|
||||
|
||||
|
||||
class CommandFilter(filters.FilterSet):
|
||||
@@ -79,7 +79,34 @@ class CommandStorageFilter(filters.FilterSet):
|
||||
model = CommandStorage
|
||||
fields = ['real', 'name', 'type', 'is_default']
|
||||
|
||||
def filter_real(self, queryset, name, value):
|
||||
@staticmethod
|
||||
def filter_real(queryset, name, value):
|
||||
if value:
|
||||
queryset = queryset.exclude(name='null')
|
||||
return queryset
|
||||
|
||||
|
||||
class ChatFilter(filters.FilterSet):
|
||||
ids = filters.BooleanFilter(method='filter_ids')
|
||||
folder_ids = filters.BooleanFilter(method='filter_folder_ids')
|
||||
|
||||
|
||||
class Meta:
|
||||
model = Chat
|
||||
fields = [
|
||||
'title', 'user_id', 'pinned', 'folder_id',
|
||||
'archived', 'socket_id', 'share_id'
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def filter_ids(queryset, name, value):
|
||||
ids = value.split(',')
|
||||
queryset = queryset.filter(id__in=ids)
|
||||
return queryset
|
||||
|
||||
|
||||
@staticmethod
|
||||
def filter_folder_ids(queryset, name, value):
|
||||
ids = value.split(',')
|
||||
queryset = queryset.filter(folder_id__in=ids)
|
||||
return queryset
|
||||
|
||||
38
apps/terminal/migrations/0011_chat.py
Normal file
38
apps/terminal/migrations/0011_chat.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# Generated by Django 4.1.13 on 2025-09-30 06:57
|
||||
|
||||
from django.db import migrations, models
|
||||
import uuid
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('terminal', '0010_alter_command_risk_level_alter_session_login_from_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='Chat',
|
||||
fields=[
|
||||
('created_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by')),
|
||||
('updated_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Updated by')),
|
||||
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
|
||||
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
|
||||
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
|
||||
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
|
||||
('title', models.CharField(max_length=256, verbose_name='Title')),
|
||||
('chat', models.JSONField(default=dict, verbose_name='Chat')),
|
||||
('meta', models.JSONField(default=dict, verbose_name='Meta')),
|
||||
('pinned', models.BooleanField(default=False, verbose_name='Pinned')),
|
||||
('archived', models.BooleanField(default=False, verbose_name='Archived')),
|
||||
('share_id', models.CharField(blank=True, default='', max_length=36)),
|
||||
('folder_id', models.CharField(blank=True, default='', max_length=36)),
|
||||
('socket_id', models.CharField(blank=True, default='', max_length=36)),
|
||||
('user_id', models.CharField(blank=True, db_index=True, default='', max_length=36)),
|
||||
('session_info', models.JSONField(default=dict, verbose_name='Session Info')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'Chat',
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -1,4 +1,5 @@
|
||||
from .session import *
|
||||
from .component import *
|
||||
from .applet import *
|
||||
from .chat import *
|
||||
from .component import *
|
||||
from .session import *
|
||||
from .virtualapp import *
|
||||
|
||||
1
apps/terminal/models/chat/__init__.py
Normal file
1
apps/terminal/models/chat/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .chat import *
|
||||
30
apps/terminal/models/chat/chat.py
Normal file
30
apps/terminal/models/chat/chat.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common.db.models import JMSBaseModel
|
||||
from common.utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
__all__ = ['Chat']
|
||||
|
||||
|
||||
class Chat(JMSBaseModel):
|
||||
# id == session_id # 36 chars
|
||||
title = models.CharField(max_length=256, verbose_name=_('Title'))
|
||||
chat = models.JSONField(default=dict, verbose_name=_('Chat'))
|
||||
meta = models.JSONField(default=dict, verbose_name=_('Meta'))
|
||||
pinned = models.BooleanField(default=False, verbose_name=_('Pinned'))
|
||||
archived = models.BooleanField(default=False, verbose_name=_('Archived'))
|
||||
share_id = models.CharField(blank=True, default='', max_length=36)
|
||||
folder_id = models.CharField(blank=True, default='', max_length=36)
|
||||
socket_id = models.CharField(blank=True, default='', max_length=36)
|
||||
user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
|
||||
|
||||
session_info = models.JSONField(default=dict, verbose_name=_('Session Info'))
|
||||
|
||||
class Meta:
|
||||
verbose_name = _('Chat')
|
||||
|
||||
def __str__(self):
|
||||
return self.title
|
||||
@@ -123,11 +123,10 @@ class Terminal(StorageMixin, TerminalStatusMixin, JMSBaseModel):
|
||||
def get_chat_ai_setting():
|
||||
data = get_chatai_data()
|
||||
return {
|
||||
'GPT_BASE_URL': data['url'],
|
||||
'GPT_API_KEY': data['api_key'],
|
||||
'GPT_PROXY': data['proxy'],
|
||||
'GPT_MODEL': data['model'],
|
||||
'CHAT_AI_TYPE': settings.CHAT_AI_TYPE,
|
||||
'GPT_BASE_URL': data.get('url'),
|
||||
'GPT_API_KEY': data.get('api_key'),
|
||||
'GPT_PROXY': data.get('proxy'),
|
||||
'CHAT_AI_PROVIDERS': settings.CHAT_AI_PROVIDERS,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
#
|
||||
from .applet import *
|
||||
from .applet_host import *
|
||||
from .chat import *
|
||||
from .command import *
|
||||
from .endpoint import *
|
||||
from .loki import *
|
||||
from .session import *
|
||||
from .sharing import *
|
||||
from .storage import *
|
||||
@@ -11,4 +13,3 @@ from .task import *
|
||||
from .terminal import *
|
||||
from .virtualapp import *
|
||||
from .virtualapp_provider import *
|
||||
from .loki import *
|
||||
|
||||
28
apps/terminal/serializers/chat.py
Normal file
28
apps/terminal/serializers/chat.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.serializers import CommonBulkModelSerializer
|
||||
from terminal.models import Chat
|
||||
|
||||
__all__ = ['ChatSerializer']
|
||||
|
||||
|
||||
class ChatSerializer(CommonBulkModelSerializer):
|
||||
created_at = serializers.SerializerMethodField()
|
||||
updated_at = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = Chat
|
||||
fields_mini = ['id', 'title', 'created_at', 'updated_at']
|
||||
fields = fields_mini + [
|
||||
'chat', 'meta', 'pinned', 'archived',
|
||||
'share_id', 'folder_id',
|
||||
'user_id', 'session_info'
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_created_at(obj):
|
||||
return int(obj.date_created.timestamp())
|
||||
|
||||
@staticmethod
|
||||
def get_updated_at(obj):
|
||||
return int(obj.date_updated.timestamp())
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user