Compare commits

..

1 Commits

Author SHA1 Message Date
feng
d1f3aa21be perf: Top session asset user cache 2025-12-09 17:58:58 +08:00
45 changed files with 1123 additions and 760 deletions

View File

@@ -150,7 +150,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
auto_config = serializers.DictField(read_only=True, label=_('Auto info')) auto_config = serializers.DictField(read_only=True, label=_('Auto info'))
platform = ObjectRelatedField(queryset=Platform.objects, required=True, label=_('Platform'), platform = ObjectRelatedField(queryset=Platform.objects, required=True, label=_('Platform'),
attrs=('id', 'name', 'type')) attrs=('id', 'name', 'type'))
spec_info = serializers.DictField(read_only=True, label=_('Spec info'))
accounts_amount = serializers.IntegerField(read_only=True, label=_('Accounts amount')) accounts_amount = serializers.IntegerField(read_only=True, label=_('Accounts amount'))
_accounts = None _accounts = None
@@ -165,8 +164,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
'directory_services', 'directory_services',
] ]
read_only_fields = [ read_only_fields = [
'accounts_amount', 'category', 'type', 'connectivity', 'accounts_amount', 'category', 'type', 'connectivity', 'auto_config',
'auto_config', 'spec_info',
'date_verified', 'created_by', 'date_created', 'date_updated', 'date_verified', 'created_by', 'date_created', 'date_updated',
] ]
fields = fields_small + fields_fk + fields_m2m + read_only_fields fields = fields_small + fields_fk + fields_m2m + read_only_fields
@@ -188,7 +186,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._init_field_choices() self._init_field_choices()
self._extract_accounts() self._extract_accounts()
self._set_platform()
def _extract_accounts(self): def _extract_accounts(self):
if not getattr(self, 'initial_data', None): if not getattr(self, 'initial_data', None):
@@ -220,21 +217,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
protocols_data = [{'name': p.name, 'port': p.port} for p in protocols] protocols_data = [{'name': p.name, 'port': p.port} for p in protocols]
self.initial_data['protocols'] = protocols_data self.initial_data['protocols'] = protocols_data
def _set_platform(self):
if not hasattr(self, 'initial_data'):
return
platform_id = self.initial_data.get('platform')
if not platform_id:
return
if isinstance(platform_id, int) or str(platform_id).isdigit() or not isinstance(platform_id, str):
return
platform = Platform.objects.filter(name=platform_id).first()
if not platform:
return
self.initial_data['platform'] = platform.id
def _init_field_choices(self): def _init_field_choices(self):
request = self.context.get('request') request = self.context.get('request')
if not request: if not request:
@@ -249,19 +231,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
return return
field_type.choices = AllTypes.filter_choices(category) field_type.choices = AllTypes.filter_choices(category)
@staticmethod
def get_spec_info(obj):
return {}
def get_auto_config(self, obj):
return obj.auto_config()
def get_gathered_info(self, obj):
return obj.gathered_info()
def get_accounts_amount(self, obj):
return obj.accounts_amount()
@classmethod @classmethod
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """ """ Perform necessary eager loading of data. """
@@ -296,10 +265,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
if not platform_id and self.instance: if not platform_id and self.instance:
platform = self.instance.platform platform = self.instance.platform
elif isinstance(platform_id, int):
platform = Platform.objects.filter(id=platform_id).first()
else: else:
platform = Platform.objects.filter(name=platform_id).first() platform = Platform.objects.filter(id=platform_id).first()
if not platform: if not platform:
raise serializers.ValidationError({'platform': _("Platform not exist")}) raise serializers.ValidationError({'platform': _("Platform not exist")})
@@ -330,7 +297,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
def is_valid(self, raise_exception=False): def is_valid(self, raise_exception=False):
self._set_protocols_default() self._set_protocols_default()
self._set_platform()
return super().is_valid(raise_exception=raise_exception) return super().is_valid(raise_exception=raise_exception)
def validate_protocols(self, protocols_data): def validate_protocols(self, protocols_data):
@@ -456,7 +422,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
class DetailMixin(serializers.Serializer): class DetailMixin(serializers.Serializer):
accounts = AssetAccountSerializer(many=True, required=False, label=_('Accounts')) accounts = AssetAccountSerializer(many=True, required=False, label=_('Accounts'))
spec_info = MethodSerializer(label=_('Spec info'), read_only=True, required=False) spec_info = MethodSerializer(label=_('Spec info'), read_only=True)
gathered_info = MethodSerializer(label=_('Gathered info'), read_only=True) gathered_info = MethodSerializer(label=_('Gathered info'), read_only=True)
auto_config = serializers.DictField(read_only=True, label=_('Auto info')) auto_config = serializers.DictField(read_only=True, label=_('Auto info'))

View File

@@ -67,7 +67,6 @@ class UserLoginMFAView(mixins.AuthMixin, FormView):
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
user = self.get_user_from_session() user = self.get_user_from_session()
mfa_context = self.get_user_mfa_context(user) mfa_context = self.get_user_mfa_context(user)
print(mfa_context)
kwargs.update(mfa_context) kwargs.update(mfa_context)
return kwargs return kwargs

View File

@@ -1,3 +1,4 @@
from .aggregate import *
from .dashboard import IndexApi from .dashboard import IndexApi
from .health import PrometheusMetricsApi, HealthCheckView from .health import PrometheusMetricsApi, HealthCheckView
from .search import GlobalSearchView from .search import GlobalSearchView

View File

@@ -0,0 +1,9 @@
from .detail import ResourceDetailApi
from .list import ResourceListApi
from .supported import ResourceTypeListApi
__all__ = [
'ResourceListApi',
'ResourceDetailApi',
'ResourceTypeListApi',
]

View File

@@ -0,0 +1,57 @@
list_params = [
{
"name": "search",
"in": "query",
"description": "A search term.",
"required": False,
"type": "string"
},
{
"name": "order",
"in": "query",
"description": "Which field to use when ordering the results.",
"required": False,
"type": "string"
},
{
"name": "limit",
"in": "query",
"description": "Number of results to return per page. Default is 10.",
"required": False,
"type": "integer"
},
{
"name": "offset",
"in": "query",
"description": "The initial index from which to return the results.",
"required": False,
"type": "integer"
},
]
common_params = [
{
"name": "resource",
"in": "path",
"description": """Resource to query, e.g. users, assets, permissions, acls, user-groups, policies, nodes, hosts,
devices, clouds, webs, databases,
gpts, ds, customs, platforms, zones, gateways, protocol-settings, labels, virtual-accounts,
gathered-accounts, account-templates, account-template-secrets, account-backups, account-backup-executions,
change-secret-automations, change-secret-executions, change-secret-records, gather-account-automations,
gather-account-executions, push-account-automations, push-account-executions, push-account-records,
check-account-automations, check-account-executions, account-risks, integration-apps, asset-permissions,
zones, gateways, virtual-accounts, gathered-accounts, account-templates, account-template-secrets,,
GET /api/v1/resources/ to get full supported resource.
""",
"required": True,
"type": "string"
},
{
"name": "X-JMS-ORG",
"in": "header",
"description": "The organization ID to use for the request. Organization is the namespace for resources, if not set, use default org",
"required": False,
"type": "string"
}
]

View File

@@ -0,0 +1,75 @@
# views.py
from drf_spectacular.utils import extend_schema, OpenApiParameter
from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView
from .const import common_params
from .proxy import ProxyMixin
from .utils import param_dic_to_param
one_param = [
{
'name': 'id',
'in': 'path',
'required': True,
'description': 'Resource ID',
'type': 'string',
}
]
object_params = [
param_dic_to_param(d)
for d in common_params + one_param
]
class ResourceDetailApi(ProxyMixin, APIView):
permission_classes = [IsAuthenticated]
@extend_schema(
operation_id="get_resource_detail",
summary="Get resource detail",
parameters=object_params,
description="""
Get resource detail.
{resource} is the resource name, GET /api/v1/resources/ to get full supported resource.
""",
)
def get(self, request, resource, pk=None):
return self._proxy(request, resource, pk=pk, action='retrieve')
@extend_schema(
operation_id="delete_resource",
summary="Delete the resource",
parameters=object_params,
description="Delete the resource, and can not be restored",
)
def delete(self, request, resource, pk=None):
return self._proxy(request, resource, pk, action='destroy')
@extend_schema(
operation_id="update_resource",
summary="Update the resource property",
parameters=object_params,
description="""
Update the resource property, all property will be update,
{resource} is the resource name, GET /api/v1/resources/ to get full supported resource.
OPTION /api/v1/resources/{resource}/{id}/?action=put to get field type and helptext.
""",
)
def put(self, request, resource, pk=None):
return self._proxy(request, resource, pk, action='update')
@extend_schema(
operation_id="partial_update_resource",
summary="Update the resource property",
parameters=object_params,
description="""
Partial update the resource property, only request property will be update,
OPTION /api/v1/resources/{resource}/{id}/?action=patch to get field type and helptext.
""",
)
def patch(self, request, resource, pk=None):
return self._proxy(request, resource, pk, action='partial_update')

View File

@@ -0,0 +1,87 @@
# views.py
from drf_spectacular.utils import extend_schema
from rest_framework.routers import DefaultRouter
from rest_framework.views import APIView
from .const import list_params, common_params
from .proxy import ProxyMixin
from .utils import param_dic_to_param
router = DefaultRouter()
BASE_URL = "http://localhost:8080"
list_params = [
param_dic_to_param(d)
for d in list_params + common_params
]
create_params = [
param_dic_to_param(d)
for d in common_params
]
list_schema = {
"required": [
"count",
"results"
],
"type": "object",
"properties": {
"count": {
"type": "integer"
},
"next": {
"type": "string",
"format": "uri",
"x-nullable": True
},
"previous": {
"type": "string",
"format": "uri",
"x-nullable": True
},
"results": {
"type": "array",
"items": {
}
}
}
}
from drf_spectacular.openapi import OpenApiResponse, OpenApiExample
class ResourceListApi(ProxyMixin, APIView):
@extend_schema(
operation_id="get_resource_list",
summary="Get resource list",
parameters=list_params,
responses={200: OpenApiResponse(description="Resource list response")},
description="""
Get resource list, you should set the resource name in the url.
OPTIONS /api/v1/resources/{resource}/?action=get to get every type resource's field type and help text.
""",
)
# ↓↓↓ Swagger 自动文档 ↓↓↓
def get(self, request, resource):
return self._proxy(request, resource)
@extend_schema(
operation_id="create_resource_by_type",
summary="Create resource",
parameters=create_params,
description="""
Create resource,
OPTIONS /api/v1/resources/{resource}/?action=post to get every resource type field type and helptext, and
you will know how to create it.
""",
)
def post(self, request, resource, pk=None):
if not resource:
resource = request.data.pop('resource', '')
return self._proxy(request, resource, pk, action='create')
def options(self, request, resource, pk=None):
return self._proxy(request, resource, pk, action='metadata')

View File

@@ -0,0 +1,75 @@
# views.py
from urllib.parse import urlencode
import requests
from rest_framework.exceptions import NotFound, APIException
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter
from rest_framework.views import APIView
from .utils import get_full_resource_map
router = DefaultRouter()
BASE_URL = "http://localhost:8080"
class ProxyMixin(APIView):
"""
通用资源代理 API支持动态路径、自动文档生成
"""
permission_classes = [IsAuthenticated]
def _build_url(self, resource_name: str, pk: str = None, query_params=None):
resource_map = get_full_resource_map()
resource = resource_map.get(resource_name)
if not resource:
raise NotFound(f"Unknown resource: {resource_name}")
base_path = resource['path']
if pk:
base_path += f"{pk}/"
if query_params:
base_path += f"?{urlencode(query_params)}"
return f"{BASE_URL}{base_path}"
def _proxy(self, request, resource: str, pk: str = None, action='list'):
method = request.method.lower()
if method not in ['get', 'post', 'put', 'patch', 'delete', 'options']:
raise APIException("Unsupported method")
if not resource or resource == '{resource}':
if request.data:
resource = request.data.get('resource')
query_params = request.query_params.dict()
if action == 'list':
query_params['limit'] = 10
url = self._build_url(resource, pk, query_params)
headers = {k: v for k, v in request.headers.items() if k.lower() != 'host'}
cookies = request.COOKIES
body = request.body if method in ['post', 'put', 'patch'] else None
try:
resp = requests.request(
method=method,
url=url,
headers=headers,
cookies=cookies,
data=body,
timeout=10,
)
content_type = resp.headers.get('Content-Type', '')
if 'application/json' in content_type:
data = resp.json()
else:
data = resp.text # 或者 bytesresp.content
return Response(data=data, status=resp.status_code)
except requests.RequestException as e:
raise APIException(f"Proxy request failed: {str(e)}")

View File

@@ -0,0 +1,45 @@
# views.py
from drf_spectacular.utils import extend_schema
from rest_framework import serializers
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter
from rest_framework.views import APIView
from .utils import get_full_resource_map
router = DefaultRouter()
BASE_URL = "http://localhost:8080"
class ResourceTypeResourceSerializer(serializers.Serializer):
name = serializers.CharField()
path = serializers.CharField()
app = serializers.CharField()
verbose_name = serializers.CharField()
description = serializers.CharField()
class ResourceTypeListApi(APIView):
permission_classes = [IsAuthenticated]
@extend_schema(
operation_id="get_supported_resources",
summary="Get-all-support-resources",
description="Get all support resources, name, path, verbose_name description",
responses={200: ResourceTypeResourceSerializer(many=True)}, # Specify the response serializer
)
def get(self, request):
result = []
resource_map = get_full_resource_map()
for name, desc in resource_map.items():
desc = resource_map.get(name, {})
resource = {
"name": name,
**desc,
"path": f'/api/v1/resources/{name}/',
}
result.append(resource)
return Response(result)

View File

@@ -0,0 +1,128 @@
# views.py
import re
from functools import lru_cache
from typing import Dict
from django.urls import URLPattern
from django.urls import URLResolver
from drf_spectacular.utils import OpenApiParameter
from rest_framework.routers import DefaultRouter
router = DefaultRouter()
BASE_URL = "http://localhost:8080"
def clean_path(path: str) -> str:
"""
清理掉 DRF 自动生成的正则格式内容,让其变成普通 RESTful URL path。
"""
# 去掉格式后缀匹配: \.(?P<format>xxx)
path = re.sub(r'\\\.\(\?P<format>[^)]+\)', '', path)
# 去掉括号式格式匹配
path = re.sub(r'\(\?P<format>[^)]+\)', '', path)
# 移除 DRF 中正则参数的部分 (?P<param>pattern)
path = re.sub(r'\(\?P<\w+>[^)]+\)', '{param}', path)
# 如果有多个括号包裹的正则(比如前缀路径),去掉可选部分包装
path = re.sub(r'\(\(([^)]+)\)\?\)', r'\1', path) # ((...))? => ...
# 去掉中间和两边的 ^ 和 $
path = path.replace('^', '').replace('$', '')
# 去掉尾部 ?/
path = re.sub(r'\?/?$', '', path)
# 去掉反斜杠
path = path.replace('\\', '')
# 替换多重斜杠
path = re.sub(r'/+', '/', path)
# 添加开头斜杠,移除多余空格
path = path.strip()
if not path.startswith('/'):
path = '/' + path
if not path.endswith('/'):
path += '/'
return path
def extract_resource_paths(urlpatterns, prefix='/api/v1/') -> Dict[str, Dict[str, str]]:
resource_map = {}
for pattern in urlpatterns:
if isinstance(pattern, URLResolver):
nested_prefix = prefix + str(pattern.pattern)
resource_map.update(extract_resource_paths(pattern.url_patterns, nested_prefix))
elif isinstance(pattern, URLPattern):
callback = pattern.callback
actions = getattr(callback, 'actions', {})
if not actions:
continue
if 'get' in actions and actions['get'] == 'list':
path = clean_path(prefix + str(pattern.pattern))
# 尝试获取资源名称
name = pattern.name
if name and name.endswith('-list'):
resource = name[:-5]
else:
resource = path.strip('/').split('/')[-1]
# 不强行加 s资源名保持原状即可
resource = resource if resource.endswith('s') else resource + 's'
# 获取 View 类和 model 的 verbose_name
view_cls = getattr(callback, 'cls', None)
model = None
if view_cls:
queryset = getattr(view_cls, 'queryset', None)
if queryset is not None:
model = getattr(queryset, 'model', None)
else:
# 有些 View 用 get_queryset()
try:
instance = view_cls()
qs = instance.get_queryset()
model = getattr(qs, 'model', None)
except Exception:
pass
if not model:
continue
app = str(getattr(model._meta, 'app_label', ''))
verbose_name = str(getattr(model._meta, 'verbose_name', ''))
resource_map[resource] = {
'path': path,
'app': app,
'verbose_name': verbose_name,
'description': model.__doc__.__str__()
}
print("Extracted resource paths:", list(resource_map.keys()))
return resource_map
def param_dic_to_param(d):
return OpenApiParameter(
name=d['name'], location=d['in'],
description=d['description'], type=d['type'], required=d.get('required', False)
)
@lru_cache()
def get_full_resource_map():
from apps.jumpserver.urls import resource_api
resource_map = extract_resource_paths(resource_api)
print("Building URL for resource:", resource_map)
return resource_map

View File

@@ -1,5 +1,6 @@
from collections import defaultdict from collections import defaultdict
from django.core.cache import cache
from django.db.models import Count, Max, F, CharField, Q from django.db.models import Count, Max, F, CharField, Q
from django.db.models.functions import Cast from django.db.models.functions import Cast
from django.http.response import JsonResponse from django.http.response import JsonResponse
@@ -144,6 +145,7 @@ class DateTimeMixin:
class DatesLoginMetricMixin: class DatesLoginMetricMixin:
days: int
dates_list: list dates_list: list
date_start_end: tuple date_start_end: tuple
command_type_queryset_list: list command_type_queryset_list: list
@@ -155,6 +157,8 @@ class DatesLoginMetricMixin:
operate_logs_queryset: OperateLog.objects operate_logs_queryset: OperateLog.objects
password_change_logs_queryset: PasswordChangeLog.objects password_change_logs_queryset: PasswordChangeLog.objects
CACHE_TIMEOUT = 60
@lazyproperty @lazyproperty
def get_type_to_assets(self): def get_type_to_assets(self):
result = Asset.objects.annotate(type=F('platform__type')). \ result = Asset.objects.annotate(type=F('platform__type')). \
@@ -214,19 +218,34 @@ class DatesLoginMetricMixin:
return date_metrics_dict.get('id', []) return date_metrics_dict.get('id', [])
def get_dates_login_times_assets(self): def get_dates_login_times_assets(self):
cache_key = f"stats:top10_assets:{self.days}"
data = cache.get(cache_key)
if data is not None:
return data
assets = self.sessions_queryset.values("asset") \ assets = self.sessions_queryset.values("asset") \
.annotate(total=Count("asset")) \ .annotate(total=Count("asset")) \
.annotate(last=Cast(Max("date_start"), output_field=CharField())) \ .annotate(last=Cast(Max("date_start"), output_field=CharField())) \
.order_by("-total") .order_by("-total")
return list(assets[:10])
result = list(assets[:10])
cache.set(cache_key, result, self.CACHE_TIMEOUT)
return result
def get_dates_login_times_users(self): def get_dates_login_times_users(self):
cache_key = f"stats:top10_users:{self.days}"
data = cache.get(cache_key)
if data is not None:
return data
users = self.sessions_queryset.values("user_id") \ users = self.sessions_queryset.values("user_id") \
.annotate(total=Count("user_id")) \ .annotate(total=Count("user_id")) \
.annotate(user=Max('user')) \ .annotate(user=Max('user')) \
.annotate(last=Cast(Max("date_start"), output_field=CharField())) \ .annotate(last=Cast(Max("date_start"), output_field=CharField())) \
.order_by("-total") .order_by("-total")
return list(users[:10]) result = list(users[:10])
cache.set(cache_key, result, self.CACHE_TIMEOUT)
return result
def get_dates_login_record_sessions(self): def get_dates_login_record_sessions(self):
sessions = self.sessions_queryset.order_by('-date_start') sessions = self.sessions_queryset.order_by('-date_start')

View File

@@ -701,7 +701,15 @@ class Config(dict):
'CHAT_AI_ENABLED': False, 'CHAT_AI_ENABLED': False,
'CHAT_AI_METHOD': 'api', 'CHAT_AI_METHOD': 'api',
'CHAT_AI_EMBED_URL': '', 'CHAT_AI_EMBED_URL': '',
'CHAT_AI_PROVIDERS': [], 'CHAT_AI_TYPE': 'gpt',
'GPT_BASE_URL': '',
'GPT_API_KEY': '',
'GPT_PROXY': '',
'GPT_MODEL': 'gpt-4o-mini',
'DEEPSEEK_BASE_URL': '',
'DEEPSEEK_API_KEY': '',
'DEEPSEEK_PROXY': '',
'DEEPSEEK_MODEL': 'deepseek-chat',
'VIRTUAL_APP_ENABLED': False, 'VIRTUAL_APP_ENABLED': False,
'FILE_UPLOAD_SIZE_LIMIT_MB': 200, 'FILE_UPLOAD_SIZE_LIMIT_MB': 200,

View File

@@ -5,7 +5,6 @@ from rest_framework.pagination import LimitOffsetPagination
class MaxLimitOffsetPagination(LimitOffsetPagination): class MaxLimitOffsetPagination(LimitOffsetPagination):
max_limit = settings.MAX_PAGE_SIZE max_limit = settings.MAX_PAGE_SIZE
default_limit = settings.DEFAULT_PAGE_SIZE
def get_count(self, queryset): def get_count(self, queryset):
try: try:

View File

@@ -241,7 +241,15 @@ ASSET_SIZE = 'small'
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
CHAT_AI_METHOD = CONFIG.CHAT_AI_METHOD CHAT_AI_METHOD = CONFIG.CHAT_AI_METHOD
CHAT_AI_EMBED_URL = CONFIG.CHAT_AI_EMBED_URL CHAT_AI_EMBED_URL = CONFIG.CHAT_AI_EMBED_URL
CHAT_AI_DEFAULT_PROVIDER = CONFIG.CHAT_AI_DEFAULT_PROVIDER CHAT_AI_TYPE = CONFIG.CHAT_AI_TYPE
GPT_BASE_URL = CONFIG.GPT_BASE_URL
GPT_API_KEY = CONFIG.GPT_API_KEY
GPT_PROXY = CONFIG.GPT_PROXY
GPT_MODEL = CONFIG.GPT_MODEL
DEEPSEEK_BASE_URL = CONFIG.DEEPSEEK_BASE_URL
DEEPSEEK_API_KEY = CONFIG.DEEPSEEK_API_KEY
DEEPSEEK_PROXY = CONFIG.DEEPSEEK_PROXY
DEEPSEEK_MODEL = CONFIG.DEEPSEEK_MODEL
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
@@ -261,5 +269,3 @@ TOOL_USER_ENABLED = CONFIG.TOOL_USER_ENABLED
SUGGESTION_LIMIT = CONFIG.SUGGESTION_LIMIT SUGGESTION_LIMIT = CONFIG.SUGGESTION_LIMIT
MCP_ENABLED = CONFIG.MCP_ENABLED MCP_ENABLED = CONFIG.MCP_ENABLED
CHAT_AI_PROVIDERS = CONFIG.CHAT_AI_PROVIDERS

View File

@@ -85,7 +85,6 @@ SPECTACULAR_SETTINGS = {
'jumpserver.views.schema.LabeledChoiceFieldExtension', 'jumpserver.views.schema.LabeledChoiceFieldExtension',
'jumpserver.views.schema.BitChoicesFieldExtension', 'jumpserver.views.schema.BitChoicesFieldExtension',
'jumpserver.views.schema.LabelRelatedFieldExtension', 'jumpserver.views.schema.LabelRelatedFieldExtension',
'jumpserver.views.schema.DateTimeFieldExtension',
], ],
'SECURITY': [{'Bearer': []}], 'SECURITY': [{'Bearer': []}],
} }

View File

@@ -37,6 +37,12 @@ api_v1 = resource_api + [
path('prometheus/metrics/', api.PrometheusMetricsApi.as_view()), path('prometheus/metrics/', api.PrometheusMetricsApi.as_view()),
path('search/', api.GlobalSearchView.as_view()), path('search/', api.GlobalSearchView.as_view()),
] ]
if settings.MCP_ENABLED:
api_v1.extend([
path('resources/', api.ResourceTypeListApi.as_view(), name='resource-list'),
path('resources/<str:resource>/', api.ResourceListApi.as_view()),
path('resources/<str:resource>/<str:pk>/', api.ResourceDetailApi.as_view()),
])
app_view_patterns = [ app_view_patterns = [
path('auth/', include('authentication.urls.view_urls'), name='auth'), path('auth/', include('authentication.urls.view_urls'), name='auth'),

View File

@@ -0,0 +1,421 @@
import re
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.generators import SchemaGenerator
class CustomSchemaGenerator(SchemaGenerator):
from_mcp = False
def get_schema(self, request=None, public=False):
self.from_mcp = request.query_params.get('mcp') or request.path.endswith('swagger.json')
return super().get_schema(request, public)
class CustomAutoSchema(AutoSchema):
def __init__(self, *args, **kwargs):
self.from_mcp = kwargs.get('from_mcp', False)
super().__init__(*args, **kwargs)
def map_parsers(self):
return ['application/json']
def map_renderers(self, *args, **kwargs):
return ['application/json']
def get_tags(self):
operation_keys = self._tokenize_path()
if len(operation_keys) == 1:
return []
tags = ['_'.join(operation_keys[:2])]
return tags
def get_operation(self, path, *args, **kwargs):
if path.endswith('render-to-json/'):
return None
# if not path.startswith('/api/v1/users'):
# return None
operation = super().get_operation(path, *args, **kwargs)
if not operation:
return operation
if not operation.get('summary', ''):
operation['summary'] = operation.get('operationId')
return operation
def get_operation_id(self):
tokenized_path = self._tokenize_path()
# replace dashes as they can be problematic later in code generation
tokenized_path = [t.replace('-', '_') for t in tokenized_path]
action = ''
if hasattr(self.view, 'action'):
action = self.view.action
if not action:
if self.method == 'GET' and self._is_list_view():
action = 'list'
else:
action = self.method_mapping[self.method.lower()]
if action == "bulk_destroy":
action = "bulk_delete"
if not tokenized_path:
tokenized_path.append('root')
if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
tokenized_path.append('formatted')
return '_'.join(tokenized_path + [action])
def get_filter_parameters(self):
if not self.should_filter():
return []
fields = []
if hasattr(self.view, 'get_filter_backends'):
backends = self.view.get_filter_backends()
elif hasattr(self.view, 'filter_backends'):
backends = self.view.filter_backends
else:
backends = []
for filter_backend in backends:
fields += self.probe_inspectors(
self.filter_inspectors, 'get_filter_parameters', filter_backend()
) or []
return fields
def get_auth(self):
return [{'Bearer': []}]
def get_operation_security(self):
"""
重写操作安全配置,统一使用 Bearer token
"""
return [{'Bearer': []}]
def get_components_security_schemes(self):
"""
重写安全方案定义,避免认证类解析错误
"""
return {
'Bearer': {
'type': 'http',
'scheme': 'bearer',
'bearerFormat': 'JWT',
'description': 'JWT token for API authentication'
}
}
@staticmethod
def exclude_some_paths(path):
# 这里可以对 paths 进行处理
excludes = [
'/report/', '/render-to-json/', '/suggestions/',
'executions', 'automations', 'change-secret-records',
'change-secret-dashboard', '/copy-to-assets/',
'/move-to-assets/', 'dashboard', 'index', 'countries',
'/resources/cache/', 'profile/mfa', 'profile/password',
'profile/permissions', 'prometheus', 'constraints'
]
for p in excludes:
if path.find(p) >= 0:
return True
return False
def exclude_some_app_model(self, path):
parts = path.split('/')
if len(parts) < 5:
return False
apps = []
if self.from_mcp:
apps = [
'ops', 'tickets', 'authentication',
'settings', 'xpack', 'terminal', 'rbac',
'notifications', 'promethues', 'acls'
]
app_name = parts[3]
if app_name in apps:
return True
models = []
model = parts[4]
if self.from_mcp:
models = [
'users', 'user-groups', 'users-groups-relations', 'assets', 'hosts', 'devices', 'databases',
'webs', 'clouds', 'gpts', 'ds', 'customs', 'platforms', 'nodes', 'zones', 'gateways',
'protocol-settings', 'labels', 'virtual-accounts', 'gathered-accounts', 'account-templates',
'account-template-secrets', 'account-backups', 'account-backup-executions',
'change-secret-automations', 'change-secret-executions', 'change-secret-records',
'gather-account-automations', 'gather-account-executions', 'push-account-automations',
'push-account-executions', 'push-account-records', 'check-account-automations',
'check-account-executions', 'account-risks', 'integration-apps', 'asset-permissions',
'asset-permissions-users-relations', 'asset-permissions-user-groups-relations',
'asset-permissions-assets-relations', 'asset-permissions-nodes-relations', 'terminal-status',
'terminals', 'tasks', 'status', 'replay-storages', 'command-storages', 'session-sharing-records',
'endpoints', 'endpoint-rules', 'applets', 'applet-hosts', 'applet-publications',
'applet-host-deployments', 'virtual-apps', 'app-providers', 'virtual-app-publications',
'celery-period-tasks', 'task-executions', 'adhocs', 'playbooks', 'variables', 'ftp-logs',
'login-logs', 'operate-logs', 'password-change-logs', 'job-logs', 'jobs', 'user-sessions',
'service-access-logs', 'chatai-prompts', 'super-connection-tokens', 'flows',
'apply-assets', 'apply-nodes', 'login-acls', 'login-asset-acls', 'command-filter-acls',
'command-groups', 'connect-method-acls', 'system-msg-subscriptions', 'roles', 'role-bindings',
'system-roles', 'system-role-bindings', 'org-roles', 'org-role-bindings', 'content-types',
'labeled-resources', 'account-backup-plans', 'account-check-engines', 'account-secrets',
'change-secret', 'integration-applications', 'push-account', 'directories', 'connection-token',
'groups', 'accounts', 'resource-types', 'favorite-assets', 'activities', 'platform-automation-methods',
]
if model in models:
return True
return False
def is_excluded(self):
if self.exclude_some_paths(self.path):
return True
if self.exclude_some_app_model(self.path):
return True
return False
def get_operation(self, path, *args, **kwargs):
operation = super().get_operation(path, *args, **kwargs)
if not operation:
return operation
operation_id = operation.get('operationId')
if 'bulk' in operation_id:
return None
if not operation.get('summary', ''):
operation['summary'] = operation.get('operationId')
exclude_operations = [
'orgs_orgs_read', 'orgs_orgs_update', 'orgs_orgs_delete',
'orgs_orgs_create', 'orgs_orgs_partial_update',
]
if operation_id in exclude_operations:
return None
return operation
# 添加自定义字段的 OpenAPI 扩展
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import build_basic_type
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField, BitChoicesField
class ObjectRelatedFieldExtension(OpenApiSerializerFieldExtension):
"""
为 ObjectRelatedField 提供 OpenAPI schema
"""
target_class = ObjectRelatedField
def map_serializer_field(self, auto_schema, direction):
field = self.target
# 获取字段的基本信息
field_type = 'array' if field.many else 'object'
if field_type == 'array':
# 如果是多对多关系
return {
'type': 'array',
'items': self._get_openapi_item_schema(field),
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
else:
# 如果是一对一关系
return {
'type': 'object',
'properties': self._get_openapi_properties_schema(field),
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
def _get_openapi_item_schema(self, field):
"""
获取数组项的 OpenAPI schema
"""
return self._get_openapi_object_schema(field)
def _get_openapi_object_schema(self, field):
"""
获取对象的 OpenAPI schema
"""
properties = {}
# 动态分析 attrs 中的属性类型
for attr in field.attrs:
# 尝试从 queryset 的 model 中获取字段信息
field_type = self._infer_field_type(field, attr)
properties[attr] = {
'type': field_type,
'description': f'{attr} field'
}
return {
'type': 'object',
'properties': properties,
'required': ['id'] if 'id' in field.attrs else []
}
def _infer_field_type(self, field, attr_name):
"""
智能推断字段类型
"""
try:
# 如果有 queryset尝试从 model 中获取字段信息
if hasattr(field, 'queryset') and field.queryset is not None:
model = field.queryset.model
if hasattr(model, '_meta') and hasattr(model._meta, 'fields'):
model_field = model._meta.get_field(attr_name)
if model_field:
return self._map_django_field_type(model_field)
except Exception:
pass
# 如果没有 queryset 或无法获取字段信息,使用启发式规则
return self._heuristic_field_type(attr_name)
def _map_django_field_type(self, model_field):
"""
将 Django 字段类型映射到 OpenAPI 类型
"""
field_type = type(model_field).__name__
# 整数类型
if 'Integer' in field_type or 'BigInteger' in field_type or 'SmallInteger' in field_type:
return 'integer'
# 浮点数类型
elif 'Float' in field_type or 'Decimal' in field_type:
return 'number'
# 布尔类型
elif 'Boolean' in field_type:
return 'boolean'
# 日期时间类型
elif 'DateTime' in field_type or 'Date' in field_type or 'Time' in field_type:
return 'string'
# 文件类型
elif 'File' in field_type or 'Image' in field_type:
return 'string'
# 其他类型默认为字符串
else:
return 'string'
def _heuristic_field_type(self, attr_name):
"""
启发式推断字段类型
"""
# 基于属性名的启发式规则
if attr_name in ['is_active', 'enabled', 'visible'] or attr_name.startswith('is_'):
return 'boolean'
elif attr_name in ['count', 'number', 'size', 'amount']:
return 'integer'
elif attr_name in ['price', 'rate', 'percentage']:
return 'number'
else:
# 默认返回字符串类型
return 'string'
def _get_openapi_properties_schema(self, field):
"""
获取对象属性的 OpenAPI schema
"""
return self._get_openapi_object_schema(field)['properties']
class LabeledChoiceFieldExtension(OpenApiSerializerFieldExtension):
"""
为 LabeledChoiceField 提供 OpenAPI schema
"""
target_class = LabeledChoiceField
def map_serializer_field(self, auto_schema, direction):
field = self.target
if getattr(field, 'many', False):
return {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
else:
return {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
class BitChoicesFieldExtension(OpenApiSerializerFieldExtension):
"""
为 BitChoicesField 提供 OpenAPI schema
"""
target_class = BitChoicesField
def map_serializer_field(self, auto_schema, direction):
field = self.target
return {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
class LabelRelatedFieldExtension(OpenApiSerializerFieldExtension):
"""
为 LabelRelatedField 提供 OpenAPI schema
"""
target_class = 'common.serializers.fields.LabelRelatedField'
def map_serializer_field(self, auto_schema, direction):
field = self.target
# LabelRelatedField 返回一个包含 id, name, value, color 的对象
return {
'type': 'object',
'properties': {
'id': {
'type': 'string',
'description': 'Label ID'
},
'name': {
'type': 'string',
'description': 'Label name'
},
'value': {
'type': 'string',
'description': 'Label value'
},
'color': {
'type': 'string',
'description': 'Label color'
}
},
'required': ['id', 'name', 'value'],
'description': getattr(field, 'help_text', 'Label information'),
'title': getattr(field, 'label', 'Label'),
}

View File

@@ -1,2 +0,0 @@
from .extension import *
from .schema import *

View File

@@ -1,263 +0,0 @@
# 添加自定义字段的 OpenAPI 扩展
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import build_basic_type
from rest_framework import serializers
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField, BitChoicesField
__all__ = [
'ObjectRelatedFieldExtension', 'LabeledChoiceFieldExtension',
'BitChoicesFieldExtension', 'LabelRelatedFieldExtension',
'DateTimeFieldExtension'
]
class ObjectRelatedFieldExtension(OpenApiSerializerFieldExtension):
"""
为 ObjectRelatedField 提供 OpenAPI schema
"""
target_class = ObjectRelatedField
def map_serializer_field(self, auto_schema, direction):
field = self.target
# 获取字段的基本信息
field_type = 'array' if field.many else 'object'
if field_type == 'array':
# 如果是多对多关系
return {
'type': 'array',
'items': self._get_openapi_item_schema(field),
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
else:
# 如果是一对一关系
return {
'type': 'object',
'properties': self._get_openapi_properties_schema(field),
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
def _get_openapi_item_schema(self, field):
"""
获取数组项的 OpenAPI schema
"""
return self._get_openapi_object_schema(field)
def _get_openapi_object_schema(self, field):
"""
获取对象的 OpenAPI schema
"""
properties = {}
# 动态分析 attrs 中的属性类型
for attr in field.attrs:
# 尝试从 queryset 的 model 中获取字段信息
field_type = self._infer_field_type(field, attr)
properties[attr] = {
'type': field_type,
'description': f'{attr} field'
}
return {
'type': 'object',
'properties': properties,
'required': ['id'] if 'id' in field.attrs else []
}
def _infer_field_type(self, field, attr_name):
"""
智能推断字段类型
"""
try:
# 如果有 queryset尝试从 model 中获取字段信息
if hasattr(field, 'queryset') and field.queryset is not None:
model = field.queryset.model
if hasattr(model, '_meta') and hasattr(model._meta, 'fields'):
model_field = model._meta.get_field(attr_name)
if model_field:
return self._map_django_field_type(model_field)
except Exception:
pass
# 如果没有 queryset 或无法获取字段信息,使用启发式规则
return self._heuristic_field_type(attr_name)
def _map_django_field_type(self, model_field):
"""
将 Django 字段类型映射到 OpenAPI 类型
"""
field_type = type(model_field).__name__
# 整数类型
if 'Integer' in field_type or 'BigInteger' in field_type or 'SmallInteger' in field_type or 'AutoField' in field_type:
return 'integer'
# 浮点数类型
elif 'Float' in field_type or 'Decimal' in field_type:
return 'number'
# 布尔类型
elif 'Boolean' in field_type:
return 'boolean'
# 日期时间类型
elif 'DateTime' in field_type or 'Date' in field_type or 'Time' in field_type:
return 'string'
# 文件类型
elif 'File' in field_type or 'Image' in field_type:
return 'string'
# 其他类型默认为字符串
else:
return 'string'
def _heuristic_field_type(self, attr_name):
"""
启发式推断字段类型
"""
# 基于属性名的启发式规则
if attr_name in ['is_active', 'enabled', 'visible'] or attr_name.startswith('is_'):
return 'boolean'
elif attr_name in ['count', 'number', 'size', 'amount']:
return 'integer'
elif attr_name in ['price', 'rate', 'percentage']:
return 'number'
else:
# 默认返回字符串类型
return 'string'
def _get_openapi_properties_schema(self, field):
"""
获取对象属性的 OpenAPI schema
"""
return self._get_openapi_object_schema(field)['properties']
class LabeledChoiceFieldExtension(OpenApiSerializerFieldExtension):
"""
为 LabeledChoiceField 提供 OpenAPI schema
"""
target_class = LabeledChoiceField
def map_serializer_field(self, auto_schema, direction):
field = self.target
if getattr(field, 'many', False):
return {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
else:
return {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
class BitChoicesFieldExtension(OpenApiSerializerFieldExtension):
"""
为 BitChoicesField 提供 OpenAPI schema
"""
target_class = BitChoicesField
def map_serializer_field(self, auto_schema, direction):
field = self.target
return {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
class LabelRelatedFieldExtension(OpenApiSerializerFieldExtension):
"""
为 LabelRelatedField 提供 OpenAPI schema
"""
target_class = 'common.serializers.fields.LabelRelatedField'
def map_serializer_field(self, auto_schema, direction):
field = self.target
# LabelRelatedField 返回一个包含 id, name, value, color 的对象
return {
'type': 'object',
'properties': {
'id': {
'type': 'string',
'description': 'Label ID'
},
'name': {
'type': 'string',
'description': 'Label name'
},
'value': {
'type': 'string',
'description': 'Label value'
},
'color': {
'type': 'string',
'description': 'Label color'
}
},
'required': ['id', 'name', 'value'],
'description': getattr(field, 'help_text', 'Label information'),
'title': getattr(field, 'label', 'Label'),
}
class DateTimeFieldExtension(OpenApiSerializerFieldExtension):
"""
为 DateTimeField 提供自定义 OpenAPI schema
修正 datetime 字段格式,使其符合实际返回格式 '%Y/%m/%d %H:%M:%S %z'
而不是标准的 ISO 8601 格式 (date-time)
"""
target_class = serializers.DateTimeField
def map_serializer_field(self, auto_schema, direction):
field = self.target
# 获取字段的描述信息,确保始终是字符串类型
help_text = getattr(field, 'help_text', None) or ''
description = help_text if isinstance(help_text, str) else ''
# 添加格式说明
format_desc = 'Format: YYYY/MM/DD HH:MM:SS +TZ (e.g., 2023/10/01 12:00:00 +0800)'
if description:
description = f'{description} {format_desc}'
else:
description = format_desc
# 返回字符串类型,不包含 format: date-time
# 因为实际返回格式是 '%Y/%m/%d %H:%M:%S %z',不是标准的 ISO 8601
schema = {
'type': 'string',
'description': description,
'title': getattr(field, 'label', '') or '',
'example': '2023/10/01 12:00:00 +0800',
}
return schema

View File

@@ -1,206 +0,0 @@
import re
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.generators import SchemaGenerator
__all__ = [
'CustomSchemaGenerator', 'CustomAutoSchema'
]
class CustomSchemaGenerator(SchemaGenerator):
from_mcp = False
def get_schema(self, request=None, public=False):
self.from_mcp = request.query_params.get('mcp') or request.path.endswith('swagger.json')
return super().get_schema(request, public)
class CustomAutoSchema(AutoSchema):
def __init__(self, *args, **kwargs):
self.from_mcp = True
super().__init__(*args, **kwargs)
def map_parsers(self):
return ['application/json']
def map_renderers(self, *args, **kwargs):
return ['application/json']
def get_tags(self):
operation_keys = self._tokenize_path()
if len(operation_keys) == 1:
return []
tags = ['_'.join(operation_keys[:2])]
return tags
def get_operation_id(self):
tokenized_path = self._tokenize_path()
# replace dashes as they can be problematic later in code generation
tokenized_path = [t.replace('-', '_') for t in tokenized_path]
action = ''
if hasattr(self.view, 'action'):
action = self.view.action
if not action:
if self.method == 'GET' and self._is_list_view():
action = 'list'
else:
action = self.method_mapping[self.method.lower()]
if action == "bulk_destroy":
action = "bulk_delete"
if not tokenized_path:
tokenized_path.append('root')
if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
tokenized_path.append('formatted')
return '_'.join(tokenized_path + [action])
def get_filter_parameters(self):
if not self.should_filter():
return []
fields = []
if hasattr(self.view, 'get_filter_backends'):
backends = self.view.get_filter_backends()
elif hasattr(self.view, 'filter_backends'):
backends = self.view.filter_backends
else:
backends = []
for filter_backend in backends:
fields += self.probe_inspectors(
self.filter_inspectors, 'get_filter_parameters', filter_backend()
) or []
return fields
def get_auth(self):
return [{'Bearer': []}]
def get_operation_security(self):
"""
重写操作安全配置,统一使用 Bearer token
"""
return [{'Bearer': []}]
def get_components_security_schemes(self):
"""
重写安全方案定义,避免认证类解析错误
"""
return {
'Bearer': {
'type': 'http',
'scheme': 'bearer',
'bearerFormat': 'JWT',
'description': 'JWT token for API authentication'
}
}
@staticmethod
def exclude_some_paths(path):
# 这里可以对 paths 进行处理
excludes = [
'/report/', '/render-to-json/', '/suggestions/',
'executions', 'automations', 'change-secret-records',
'change-secret-dashboard', '/copy-to-assets/',
'/move-to-assets/', 'dashboard', 'index', 'countries',
'/resources/cache/', 'profile/mfa', 'profile/password',
'profile/permissions', 'prometheus', 'constraints',
'/api/swagger.json', '/api/swagger.yaml',
]
for p in excludes:
if path.find(p) >= 0:
return True
return False
def exclude_some_models(self, model):
models = []
if self.from_mcp:
models = [
'users', 'user-groups',
'assets', 'hosts', 'devices', 'databases',
'webs', 'clouds', 'ds', 'platforms',
'nodes', 'zones', 'labels',
'accounts', 'account-templates',
'asset-permissions',
]
if models and model in models:
return False
return True
def exclude_some_apps(self, app):
apps = []
if self.from_mcp:
apps = [
'users', 'assets', 'accounts',
'perms', 'labels',
]
if apps and app in apps:
return False
return True
def exclude_some_app_model(self, path):
parts = path.split('/')
if len(parts) < 5 :
return True
if len(parts) == 7 and parts[5] != "{id}":
return True
if len(parts) > 7:
return True
app_name = parts[3]
if self.exclude_some_apps(app_name):
return True
if self.exclude_some_models(parts[4]):
return True
return False
def is_excluded(self):
if self.exclude_some_paths(self.path):
return True
if self.exclude_some_app_model(self.path):
return True
return False
def exclude_some_operations(self, operation_id):
exclude_operations = [
'orgs_orgs_read', 'orgs_orgs_update', 'orgs_orgs_delete',
'orgs_orgs_create', 'orgs_orgs_partial_update',
]
if operation_id in exclude_operations:
return True
if 'bulk' in operation_id:
return True
if 'destroy' in operation_id:
return True
if 'update' in operation_id and 'partial' not in operation_id:
return True
return False
def get_operation(self, path, *args, **kwargs):
operation = super().get_operation(path, *args, **kwargs)
if not operation:
return operation
operation_id = operation.get('operationId')
if self.exclude_some_operations(operation_id):
return None
if not operation.get('summary', ''):
operation['summary'] = operation.get('operationId')
# if self.is_excluded():
# return None
return operation

View File

@@ -37,11 +37,11 @@ class SchemeMixin:
} }
return Response(schema) return Response(schema)
# @method_decorator(cache_page(60 * 5,), name="dispatch") @method_decorator(cache_page(60 * 5,), name="dispatch")
class JsonApi(SchemeMixin, SpectacularJSONAPIView): class JsonApi(SchemeMixin, SpectacularJSONAPIView):
pass pass
# @method_decorator(cache_page(60 * 5,), name="dispatch") @method_decorator(cache_page(60 * 5,), name="dispatch")
class YamlApi(SchemeMixin, SpectacularYAMLAPIView): class YamlApi(SchemeMixin, SpectacularYAMLAPIView):
pass pass

View File

@@ -1,10 +1,98 @@
import httpx
import openai
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from rest_framework import status
from rest_framework.generics import GenericAPIView
from rest_framework.response import Response
from common.api import JMSModelViewSet from common.api import JMSModelViewSet
from common.permissions import IsValidUser, OnlySuperUser from common.permissions import IsValidUser, OnlySuperUser
from .. import serializers from .. import serializers
from ..const import ChatAITypeChoices
from ..models import ChatPrompt from ..models import ChatPrompt
from ..prompt import DefaultChatPrompt from ..prompt import DefaultChatPrompt
class ChatAITestingAPI(GenericAPIView):
serializer_class = serializers.ChatAISettingSerializer
rbac_perms = {
'POST': 'settings.change_chatai'
}
def get_config(self, request):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
data = self.serializer_class().data
data.update(serializer.validated_data)
for k, v in data.items():
if v:
continue
# 页面没有传递值, 从 settings 中获取
data[k] = getattr(settings, k, None)
return data
def post(self, request):
config = self.get_config(request)
chat_ai_enabled = config['CHAT_AI_ENABLED']
if not chat_ai_enabled:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data={'msg': _('Chat AI is not enabled')}
)
tp = config['CHAT_AI_TYPE']
if tp == ChatAITypeChoices.gpt:
url = config['GPT_BASE_URL']
api_key = config['GPT_API_KEY']
proxy = config['GPT_PROXY']
model = config['GPT_MODEL']
else:
url = config['DEEPSEEK_BASE_URL']
api_key = config['DEEPSEEK_API_KEY']
proxy = config['DEEPSEEK_PROXY']
model = config['DEEPSEEK_MODEL']
kwargs = {
'base_url': url or None,
'api_key': api_key,
}
try:
if proxy:
kwargs['http_client'] = httpx.Client(
proxies=proxy,
transport=httpx.HTTPTransport(local_address='0.0.0.0')
)
client = openai.OpenAI(**kwargs)
ok = False
error = ''
client.chat.completions.create(
messages=[
{
"role": "user",
"content": "Say this is a test",
}
],
model=model,
)
ok = True
except openai.APIConnectionError as e:
error = str(e.__cause__) # an underlying Exception, likely raised within httpx.
except openai.APIStatusError as e:
error = str(e.message)
except Exception as e:
ok, error = False, str(e)
if ok:
_status, msg = status.HTTP_200_OK, _('Test success')
else:
_status, msg = status.HTTP_400_BAD_REQUEST, error
return Response(status=_status, data={'msg': msg})
class ChatPromptViewSet(JMSModelViewSet): class ChatPromptViewSet(JMSModelViewSet):
serializer_classes = { serializer_classes = {
'default': serializers.ChatPromptSerializer, 'default': serializers.ChatPromptSerializer,

View File

@@ -154,10 +154,7 @@ class SettingsApi(generics.RetrieveUpdateAPIView):
def parse_serializer_data(self, serializer): def parse_serializer_data(self, serializer):
data = [] data = []
fields = self.get_fields() fields = self.get_fields()
encrypted_items = [ encrypted_items = [name for name, field in fields.items() if field.write_only]
name for name, field in fields.items()
if field.write_only or getattr(field, 'encrypted', False)
]
category = self.request.query_params.get('category', '') category = self.request.query_params.get('category', '')
for name, value in serializer.validated_data.items(): for name, value in serializer.validated_data.items():
encrypted = name in encrypted_items encrypted = name in encrypted_items

View File

@@ -14,5 +14,18 @@ class ChatAIMethodChoices(TextChoices):
class ChatAITypeChoices(TextChoices): class ChatAITypeChoices(TextChoices):
openai = 'openai', 'Openai' gpt = 'gpt', 'GPT'
ollama = 'ollama', 'Ollama' deep_seek = 'deep-seek', 'DeepSeek'
class GPTModelChoices(TextChoices):
gpt_4o_mini = 'gpt-4o-mini', 'gpt-4o-mini'
gpt_4o = 'gpt-4o', 'gpt-4o'
o3_mini = 'o3-mini', 'o3-mini'
o1_mini = 'o1-mini', 'o1-mini'
o1 = 'o1', 'o1'
class DeepSeekModelChoices(TextChoices):
deepseek_chat = 'deepseek-chat', 'DeepSeek-V3'
deepseek_reasoner = 'deepseek-reasoner', 'DeepSeek-R1'

View File

@@ -1,7 +1,6 @@
import json import json
import os import os
import shutil import shutil
from typing import Any, Dict, List
from django.conf import settings from django.conf import settings
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
@@ -15,6 +14,7 @@ from rest_framework.utils.encoders import JSONEncoder
from common.db.models import JMSBaseModel from common.db.models import JMSBaseModel
from common.db.utils import Encryptor from common.db.utils import Encryptor
from common.utils import get_logger from common.utils import get_logger
from .const import ChatAITypeChoices
from .signals import setting_changed from .signals import setting_changed
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -196,25 +196,20 @@ class ChatPrompt(JMSBaseModel):
return self.name return self.name
def get_chatai_data() -> Dict[str, Any]: def get_chatai_data():
raw_providers = settings.CHAT_AI_PROVIDERS data = {
providers: List[dict] = [p for p in raw_providers if isinstance(p, dict)] 'url': settings.GPT_BASE_URL,
'api_key': settings.GPT_API_KEY,
if not providers: 'proxy': settings.GPT_PROXY,
return {} 'model': settings.GPT_MODEL,
selected = next(
(p for p in providers if p.get('is_assistant')),
providers[0],
)
return {
'url': selected.get('base_url'),
'api_key': selected.get('api_key'),
'proxy': selected.get('proxy'),
'model': selected.get('model'),
'name': selected.get('name'),
} }
if settings.CHAT_AI_TYPE != ChatAITypeChoices.gpt:
data['url'] = settings.DEEPSEEK_BASE_URL
data['api_key'] = settings.DEEPSEEK_API_KEY
data['proxy'] = settings.DEEPSEEK_PROXY
data['model'] = settings.DEEPSEEK_MODEL
return data
def init_sqlite_db(): def init_sqlite_db():

View File

@@ -10,12 +10,11 @@ from common.utils import date_expired_default
__all__ = [ __all__ = [
'AnnouncementSettingSerializer', 'OpsSettingSerializer', 'VaultSettingSerializer', 'AnnouncementSettingSerializer', 'OpsSettingSerializer', 'VaultSettingSerializer',
'HashicorpKVSerializer', 'AzureKVSerializer', 'TicketSettingSerializer', 'HashicorpKVSerializer', 'AzureKVSerializer', 'TicketSettingSerializer',
'ChatAIProviderSerializer', 'ChatAISettingSerializer', 'ChatAISettingSerializer', 'VirtualAppSerializer', 'AmazonSMSerializer',
'VirtualAppSerializer', 'AmazonSMSerializer',
] ]
from settings.const import ( from settings.const import (
ChatAITypeChoices, ChatAIMethodChoices ChatAITypeChoices, GPTModelChoices, DeepSeekModelChoices, ChatAIMethodChoices
) )
@@ -121,29 +120,6 @@ class AmazonSMSerializer(serializers.Serializer):
) )
class ChatAIProviderListSerializer(serializers.ListSerializer):
# 标记整个列表需要加密存储,避免明文保存 API Key
encrypted = True
class ChatAIProviderSerializer(serializers.Serializer):
type = serializers.ChoiceField(
default=ChatAITypeChoices.openai, choices=ChatAITypeChoices.choices,
label=_("Types"), required=False,
)
base_url = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
api_key = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
proxy = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
class ChatAISettingSerializer(serializers.Serializer): class ChatAISettingSerializer(serializers.Serializer):
PREFIX_TITLE = _('Chat AI') PREFIX_TITLE = _('Chat AI')
@@ -154,14 +130,44 @@ class ChatAISettingSerializer(serializers.Serializer):
default=ChatAIMethodChoices.api, choices=ChatAIMethodChoices.choices, default=ChatAIMethodChoices.api, choices=ChatAIMethodChoices.choices,
label=_("Method"), required=False, label=_("Method"), required=False,
) )
CHAT_AI_PROVIDERS = ChatAIProviderListSerializer(
child=ChatAIProviderSerializer(),
allow_empty=True, required=False, default=list, label=_('Providers')
)
CHAT_AI_EMBED_URL = serializers.CharField( CHAT_AI_EMBED_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'), allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.') help_text=_('The base URL of the Chat service.')
) )
CHAT_AI_TYPE = serializers.ChoiceField(
default=ChatAITypeChoices.gpt, choices=ChatAITypeChoices.choices,
label=_("Types"), required=False,
)
GPT_BASE_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
GPT_API_KEY = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
GPT_PROXY = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
GPT_MODEL = serializers.ChoiceField(
default=GPTModelChoices.gpt_4o_mini, choices=GPTModelChoices.choices,
label=_("GPT Model"), required=False,
)
DEEPSEEK_BASE_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
DEEPSEEK_API_KEY = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
DEEPSEEK_PROXY = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
DEEPSEEK_MODEL = serializers.ChoiceField(
default=DeepSeekModelChoices.deepseek_chat, choices=DeepSeekModelChoices.choices,
label=_("DeepSeek Model"), required=False,
)
class TicketSettingSerializer(serializers.Serializer): class TicketSettingSerializer(serializers.Serializer):

View File

@@ -73,6 +73,7 @@ class PrivateSettingSerializer(PublicSettingSerializer):
CHAT_AI_ENABLED = serializers.BooleanField() CHAT_AI_ENABLED = serializers.BooleanField()
CHAT_AI_METHOD = serializers.CharField() CHAT_AI_METHOD = serializers.CharField()
CHAT_AI_EMBED_URL = serializers.CharField() CHAT_AI_EMBED_URL = serializers.CharField()
CHAT_AI_TYPE = serializers.CharField()
GPT_MODEL = serializers.CharField() GPT_MODEL = serializers.CharField()
FILE_UPLOAD_SIZE_LIMIT_MB = serializers.IntegerField() FILE_UPLOAD_SIZE_LIMIT_MB = serializers.IntegerField()
FTP_FILE_MAX_STORE = serializers.IntegerField() FTP_FILE_MAX_STORE = serializers.IntegerField()

View File

@@ -21,6 +21,7 @@ urlpatterns = [
path('sms/<str:backend>/testing/', api.SMSTestingAPI.as_view(), name='sms-testing'), path('sms/<str:backend>/testing/', api.SMSTestingAPI.as_view(), name='sms-testing'),
path('sms/backend/', api.SMSBackendAPI.as_view(), name='sms-backend'), path('sms/backend/', api.SMSBackendAPI.as_view(), name='sms-backend'),
path('vault/<str:backend>/testing/', api.VaultTestingAPI.as_view(), name='vault-testing'), path('vault/<str:backend>/testing/', api.VaultTestingAPI.as_view(), name='vault-testing'),
path('chatai/testing/', api.ChatAITestingAPI.as_view(), name='chatai-testing'),
path('vault/sync/', api.VaultSyncDataAPI.as_view(), name='vault-sync'), path('vault/sync/', api.VaultSyncDataAPI.as_view(), name='vault-sync'),
path('security/block-ip/', api.BlockIPSecurityAPI.as_view(), name='block-ip'), path('security/block-ip/', api.BlockIPSecurityAPI.as_view(), name='block-ip'),
path('security/unlock-ip/', api.UnlockIPSecurityAPI.as_view(), name='unlock-ip'), path('security/unlock-ip/', api.UnlockIPSecurityAPI.as_view(), name='unlock-ip'),

View File

@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from .applet import * from .applet import *
from .chat import *
from .component import * from .component import *
from .session import * from .session import *
from .virtualapp import * from .virtualapp import *

View File

@@ -1 +0,0 @@
from .chat import *

View File

@@ -1,15 +0,0 @@
from common.api import JMSBulkModelViewSet
from terminal import serializers
from terminal.filters import ChatFilter
from terminal.models import Chat
__all__ = ['ChatViewSet']
class ChatViewSet(JMSBulkModelViewSet):
queryset = Chat.objects.all()
serializer_class = serializers.ChatSerializer
filterset_class = ChatFilter
search_fields = ['title']
ordering_fields = ['date_updated']
ordering = ['-date_updated']

View File

@@ -2,7 +2,7 @@ from django.db.models import QuerySet
from django_filters import rest_framework as filters from django_filters import rest_framework as filters
from orgs.utils import filter_org_queryset from orgs.utils import filter_org_queryset
from terminal.models import Command, CommandStorage, Session, Chat from terminal.models import Command, CommandStorage, Session
class CommandFilter(filters.FilterSet): class CommandFilter(filters.FilterSet):
@@ -79,34 +79,7 @@ class CommandStorageFilter(filters.FilterSet):
model = CommandStorage model = CommandStorage
fields = ['real', 'name', 'type', 'is_default'] fields = ['real', 'name', 'type', 'is_default']
@staticmethod def filter_real(self, queryset, name, value):
def filter_real(queryset, name, value):
if value: if value:
queryset = queryset.exclude(name='null') queryset = queryset.exclude(name='null')
return queryset return queryset
class ChatFilter(filters.FilterSet):
ids = filters.BooleanFilter(method='filter_ids')
folder_ids = filters.BooleanFilter(method='filter_folder_ids')
class Meta:
model = Chat
fields = [
'title', 'user_id', 'pinned', 'folder_id',
'archived', 'socket_id', 'share_id'
]
@staticmethod
def filter_ids(queryset, name, value):
ids = value.split(',')
queryset = queryset.filter(id__in=ids)
return queryset
@staticmethod
def filter_folder_ids(queryset, name, value):
ids = value.split(',')
queryset = queryset.filter(folder_id__in=ids)
return queryset

View File

@@ -1,38 +0,0 @@
# Generated by Django 4.1.13 on 2025-09-30 06:57
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
('terminal', '0010_alter_command_risk_level_alter_session_login_from_and_more'),
]
operations = [
migrations.CreateModel(
name='Chat',
fields=[
('created_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('title', models.CharField(max_length=256, verbose_name='Title')),
('chat', models.JSONField(default=dict, verbose_name='Chat')),
('meta', models.JSONField(default=dict, verbose_name='Meta')),
('pinned', models.BooleanField(default=False, verbose_name='Pinned')),
('archived', models.BooleanField(default=False, verbose_name='Archived')),
('share_id', models.CharField(blank=True, default='', max_length=36)),
('folder_id', models.CharField(blank=True, default='', max_length=36)),
('socket_id', models.CharField(blank=True, default='', max_length=36)),
('user_id', models.CharField(blank=True, db_index=True, default='', max_length=36)),
('session_info', models.JSONField(default=dict, verbose_name='Session Info')),
],
options={
'verbose_name': 'Chat',
},
),
]

View File

@@ -1,5 +1,4 @@
from .applet import *
from .chat import *
from .component import *
from .session import * from .session import *
from .component import *
from .applet import *
from .virtualapp import * from .virtualapp import *

View File

@@ -1 +0,0 @@
from .chat import *

View File

@@ -1,30 +0,0 @@
from django.db import models
from django.utils.translation import gettext_lazy as _
from common.db.models import JMSBaseModel
from common.utils import get_logger
logger = get_logger(__name__)
__all__ = ['Chat']
class Chat(JMSBaseModel):
# id == session_id # 36 chars
title = models.CharField(max_length=256, verbose_name=_('Title'))
chat = models.JSONField(default=dict, verbose_name=_('Chat'))
meta = models.JSONField(default=dict, verbose_name=_('Meta'))
pinned = models.BooleanField(default=False, verbose_name=_('Pinned'))
archived = models.BooleanField(default=False, verbose_name=_('Archived'))
share_id = models.CharField(blank=True, default='', max_length=36)
folder_id = models.CharField(blank=True, default='', max_length=36)
socket_id = models.CharField(blank=True, default='', max_length=36)
user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
session_info = models.JSONField(default=dict, verbose_name=_('Session Info'))
class Meta:
verbose_name = _('Chat')
def __str__(self):
return self.title

View File

@@ -123,11 +123,11 @@ class Terminal(StorageMixin, TerminalStatusMixin, JMSBaseModel):
def get_chat_ai_setting(): def get_chat_ai_setting():
data = get_chatai_data() data = get_chatai_data()
return { return {
'GPT_BASE_URL': data.get('url'), 'GPT_BASE_URL': data['url'],
'GPT_API_KEY': data.get('api_key'), 'GPT_API_KEY': data['api_key'],
'GPT_PROXY': data.get('proxy'), 'GPT_PROXY': data['proxy'],
'GPT_MODEL': data.get('model'), 'GPT_MODEL': data['model'],
'CHAT_AI_PROVIDERS': settings.CHAT_AI_PROVIDERS, 'CHAT_AI_TYPE': settings.CHAT_AI_TYPE,
} }
@staticmethod @staticmethod

View File

@@ -2,10 +2,8 @@
# #
from .applet import * from .applet import *
from .applet_host import * from .applet_host import *
from .chat import *
from .command import * from .command import *
from .endpoint import * from .endpoint import *
from .loki import *
from .session import * from .session import *
from .sharing import * from .sharing import *
from .storage import * from .storage import *
@@ -13,3 +11,4 @@ from .task import *
from .terminal import * from .terminal import *
from .virtualapp import * from .virtualapp import *
from .virtualapp_provider import * from .virtualapp_provider import *
from .loki import *

View File

@@ -1,28 +0,0 @@
from rest_framework import serializers
from common.serializers import CommonBulkModelSerializer
from terminal.models import Chat
__all__ = ['ChatSerializer']
class ChatSerializer(CommonBulkModelSerializer):
created_at = serializers.SerializerMethodField()
updated_at = serializers.SerializerMethodField()
class Meta:
model = Chat
fields_mini = ['id', 'title', 'created_at', 'updated_at']
fields = fields_mini + [
'chat', 'meta', 'pinned', 'archived',
'share_id', 'folder_id',
'user_id', 'session_info'
]
@staticmethod
def get_created_at(obj):
return int(obj.date_created.timestamp())
@staticmethod
def get_updated_at(obj):
return int(obj.date_updated.timestamp())

View File

@@ -32,7 +32,6 @@ router.register(r'virtual-apps', api.VirtualAppViewSet, 'virtual-app')
router.register(r'app-providers', api.AppProviderViewSet, 'app-provider') router.register(r'app-providers', api.AppProviderViewSet, 'app-provider')
router.register(r'app-providers/((?P<provider>[^/.]+)/)?apps', api.AppProviderAppViewSet, 'app-provider-app') router.register(r'app-providers/((?P<provider>[^/.]+)/)?apps', api.AppProviderAppViewSet, 'app-provider-app')
router.register(r'virtual-app-publications', api.VirtualAppPublicationViewSet, 'virtual-app-publication') router.register(r'virtual-app-publications', api.VirtualAppPublicationViewSet, 'virtual-app-publication')
router.register(r'chats', api.ChatViewSet, 'chat')
urlpatterns = [ urlpatterns = [
path('my-sessions/', api.MySessionAPIView.as_view(), name='my-session'), path('my-sessions/', api.MySessionAPIView.as_view(), name='my-session'),

View File

@@ -199,19 +199,11 @@ class UserChangePasswordApi(UserQuerysetMixin, generics.UpdateAPIView):
class UserUnblockPKApi(UserQuerysetMixin, generics.UpdateAPIView): class UserUnblockPKApi(UserQuerysetMixin, generics.UpdateAPIView):
serializer_class = serializers.UserSerializer serializer_class = serializers.UserSerializer
def get_object(self):
pk = self.kwargs.get('pk')
if is_uuid(pk):
return super().get_object()
else:
return self.get_queryset().filter(username=pk).first()
def perform_update(self, serializer): def perform_update(self, serializer):
user = self.get_object() user = self.get_object()
if not user: username = user.username if user else ''
return Response({"error": _("User not found")}, status=404) LoginBlockUtil.unblock_user(username)
MFABlockUtils.unblock_user(username)
user.unblock_login()
class UserResetMFAApi(UserQuerysetMixin, generics.RetrieveAPIView): class UserResetMFAApi(UserQuerysetMixin, generics.RetrieveAPIView):

View File

@@ -274,8 +274,8 @@ class User(
LoginBlockUtil.unblock_user(self.username) LoginBlockUtil.unblock_user(self.username)
MFABlockUtils.unblock_user(self.username) MFABlockUtils.unblock_user(self.username)
@property @lazyproperty
def is_login_blocked(self): def login_blocked(self):
from users.utils import LoginBlockUtil, MFABlockUtils from users.utils import LoginBlockUtil, MFABlockUtils
if LoginBlockUtil.is_user_block(self.username): if LoginBlockUtil.is_user_block(self.username):
@@ -284,13 +284,6 @@ class User(
return True return True
return False return False
@classmethod
def block_login(cls, username):
from users.utils import LoginBlockUtil, MFABlockUtils
LoginBlockUtil.block_user(username)
MFABlockUtils.block_user(username)
def delete(self, using=None, keep_parents=False): def delete(self, using=None, keep_parents=False):
if self.pk == 1 or self.username == "admin": if self.pk == 1 or self.username == "admin":
raise PermissionDenied(_("Can not delete admin user")) raise PermissionDenied(_("Can not delete admin user"))

View File

@@ -123,7 +123,7 @@ class UserSerializer(
mfa_force_enabled = serializers.BooleanField( mfa_force_enabled = serializers.BooleanField(
read_only=True, label=_("MFA force enabled") read_only=True, label=_("MFA force enabled")
) )
is_login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked")) login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked"))
is_expired = serializers.BooleanField(read_only=True, label=_("Is expired")) is_expired = serializers.BooleanField(read_only=True, label=_("Is expired"))
is_valid = serializers.BooleanField(read_only=True, label=_("Is valid")) is_valid = serializers.BooleanField(read_only=True, label=_("Is valid"))
is_otp_secret_key_bound = serializers.BooleanField( is_otp_secret_key_bound = serializers.BooleanField(
@@ -193,7 +193,6 @@ class UserSerializer(
"is_valid", "is_expired", "is_active", # 布尔字段 "is_valid", "is_expired", "is_active", # 布尔字段
"is_otp_secret_key_bound", "can_public_key_auth", "is_otp_secret_key_bound", "can_public_key_auth",
"mfa_enabled", "need_update_password", "is_face_code_set", "mfa_enabled", "need_update_password", "is_face_code_set",
"is_login_blocked",
] ]
# 包含不太常用的字段,可以没有 # 包含不太常用的字段,可以没有
fields_verbose = ( fields_verbose = (
@@ -212,7 +211,7 @@ class UserSerializer(
# 多对多字段 # 多对多字段
fields_m2m = ["groups", "system_roles", "org_roles", "orgs_roles", "labels"] fields_m2m = ["groups", "system_roles", "org_roles", "orgs_roles", "labels"]
# 在serializer 上定义的字段 # 在serializer 上定义的字段
fields_custom = ["is_login_blocked", "password_strategy"] fields_custom = ["login_blocked", "password_strategy"]
fields = fields_verbose + fields_fk + fields_m2m + fields_custom fields = fields_verbose + fields_fk + fields_m2m + fields_custom
fields_unexport = ["avatar_url", "is_service_account"] fields_unexport = ["avatar_url", "is_service_account"]

View File

@@ -28,6 +28,6 @@ urlpatterns = [
path('users/<uuid:pk>/password/', api.UserChangePasswordApi.as_view(), name='change-user-password'), path('users/<uuid:pk>/password/', api.UserChangePasswordApi.as_view(), name='change-user-password'),
path('users/<uuid:pk>/password/reset/', api.UserResetPasswordApi.as_view(), name='user-reset-password'), path('users/<uuid:pk>/password/reset/', api.UserResetPasswordApi.as_view(), name='user-reset-password'),
path('users/<uuid:pk>/pubkey/reset/', api.UserResetPKApi.as_view(), name='user-public-key-reset'), path('users/<uuid:pk>/pubkey/reset/', api.UserResetPKApi.as_view(), name='user-public-key-reset'),
path('users/<str:pk>/unblock/', api.UserUnblockPKApi.as_view(), name='user-unblock'), path('users/<uuid:pk>/unblock/', api.UserUnblockPKApi.as_view(), name='user-unblock'),
] ]
urlpatterns += router.urls urlpatterns += router.urls

View File

@@ -186,13 +186,6 @@ class BlockUtilBase:
def is_block(self): def is_block(self):
return bool(cache.get(self.block_key)) return bool(cache.get(self.block_key))
@classmethod
def block_user(cls, username):
username = username.lower()
block_key = cls.BLOCK_KEY_TMPL.format(username)
key_ttl = int(settings.SECURITY_LOGIN_LIMIT_TIME) * 60
cache.set(block_key, True, key_ttl)
@classmethod @classmethod
def get_blocked_usernames(cls): def get_blocked_usernames(cls):
key = cls.BLOCK_KEY_TMPL.format('*') key = cls.BLOCK_KEY_TMPL.format('*')