mirror of
https://github.com/jumpserver/jumpserver.git
synced 2025-12-15 08:32:48 +00:00
Compare commits
5 Commits
refactor_p
...
pr@v5@mcp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e41f6e27e2 | ||
|
|
d2386fb56c | ||
|
|
5f1ba56e56 | ||
|
|
2b1fdb937b | ||
|
|
1e754546f1 |
@@ -150,6 +150,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
auto_config = serializers.DictField(read_only=True, label=_('Auto info'))
|
||||
platform = ObjectRelatedField(queryset=Platform.objects, required=True, label=_('Platform'),
|
||||
attrs=('id', 'name', 'type'))
|
||||
spec_info = serializers.DictField(read_only=True, label=_('Spec info'))
|
||||
accounts_amount = serializers.IntegerField(read_only=True, label=_('Accounts amount'))
|
||||
_accounts = None
|
||||
|
||||
@@ -164,7 +165,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
'directory_services',
|
||||
]
|
||||
read_only_fields = [
|
||||
'accounts_amount', 'category', 'type', 'connectivity', 'auto_config',
|
||||
'accounts_amount', 'category', 'type', 'connectivity',
|
||||
'auto_config', 'spec_info',
|
||||
'date_verified', 'created_by', 'date_created', 'date_updated',
|
||||
]
|
||||
fields = fields_small + fields_fk + fields_m2m + read_only_fields
|
||||
@@ -186,6 +188,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
super().__init__(*args, **kwargs)
|
||||
self._init_field_choices()
|
||||
self._extract_accounts()
|
||||
self._set_platform()
|
||||
|
||||
def _extract_accounts(self):
|
||||
if not getattr(self, 'initial_data', None):
|
||||
@@ -217,6 +220,21 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
protocols_data = [{'name': p.name, 'port': p.port} for p in protocols]
|
||||
self.initial_data['protocols'] = protocols_data
|
||||
|
||||
def _set_platform(self):
|
||||
if not hasattr(self, 'initial_data'):
|
||||
return
|
||||
platform_id = self.initial_data.get('platform')
|
||||
if not platform_id:
|
||||
return
|
||||
|
||||
if isinstance(platform_id, int) or str(platform_id).isdigit() or not isinstance(platform_id, str):
|
||||
return
|
||||
|
||||
platform = Platform.objects.filter(name=platform_id).first()
|
||||
if not platform:
|
||||
return
|
||||
self.initial_data['platform'] = platform.id
|
||||
|
||||
def _init_field_choices(self):
|
||||
request = self.context.get('request')
|
||||
if not request:
|
||||
@@ -231,6 +249,19 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
return
|
||||
field_type.choices = AllTypes.filter_choices(category)
|
||||
|
||||
@staticmethod
|
||||
def get_spec_info(obj):
|
||||
return {}
|
||||
|
||||
def get_auto_config(self, obj):
|
||||
return obj.auto_config()
|
||||
|
||||
def get_gathered_info(self, obj):
|
||||
return obj.gathered_info()
|
||||
|
||||
def get_accounts_amount(self, obj):
|
||||
return obj.accounts_amount()
|
||||
|
||||
@classmethod
|
||||
def setup_eager_loading(cls, queryset):
|
||||
""" Perform necessary eager loading of data. """
|
||||
@@ -265,8 +296,10 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
|
||||
if not platform_id and self.instance:
|
||||
platform = self.instance.platform
|
||||
else:
|
||||
elif isinstance(platform_id, int):
|
||||
platform = Platform.objects.filter(id=platform_id).first()
|
||||
else:
|
||||
platform = Platform.objects.filter(name=platform_id).first()
|
||||
|
||||
if not platform:
|
||||
raise serializers.ValidationError({'platform': _("Platform not exist")})
|
||||
@@ -297,6 +330,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
|
||||
def is_valid(self, raise_exception=False):
|
||||
self._set_protocols_default()
|
||||
self._set_platform()
|
||||
return super().is_valid(raise_exception=raise_exception)
|
||||
|
||||
def validate_protocols(self, protocols_data):
|
||||
@@ -422,7 +456,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
||||
|
||||
class DetailMixin(serializers.Serializer):
|
||||
accounts = AssetAccountSerializer(many=True, required=False, label=_('Accounts'))
|
||||
spec_info = MethodSerializer(label=_('Spec info'), read_only=True)
|
||||
spec_info = MethodSerializer(label=_('Spec info'), read_only=True, required=False)
|
||||
gathered_info = MethodSerializer(label=_('Gathered info'), read_only=True)
|
||||
auto_config = serializers.DictField(read_only=True, label=_('Auto info'))
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@ class UserLoginMFAView(mixins.AuthMixin, FormView):
|
||||
def get_context_data(self, **kwargs):
|
||||
user = self.get_user_from_session()
|
||||
mfa_context = self.get_user_mfa_context(user)
|
||||
print(mfa_context)
|
||||
kwargs.update(mfa_context)
|
||||
return kwargs
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ class Device:
|
||||
self.__load_driver(driver_path)
|
||||
# open device
|
||||
self.__open_device()
|
||||
self.__reset_key_store()
|
||||
|
||||
def close(self):
|
||||
if self.__device is None:
|
||||
@@ -68,3 +69,12 @@ class Device:
|
||||
if ret != 0:
|
||||
raise PiicoError("open piico device failed", ret)
|
||||
self.__device = device
|
||||
|
||||
def __reset_key_store(self):
|
||||
if self._driver is None:
|
||||
raise PiicoError("no driver loaded", 0)
|
||||
if self.__device is None:
|
||||
raise PiicoError("device not open", 0)
|
||||
ret = self._driver.SPII_ResetModule(self.__device)
|
||||
if ret != 0:
|
||||
raise PiicoError("reset device failed", ret)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from .aggregate import *
|
||||
from .dashboard import IndexApi
|
||||
from .health import PrometheusMetricsApi, HealthCheckView
|
||||
from .search import GlobalSearchView
|
||||
@@ -1,9 +0,0 @@
|
||||
from .detail import ResourceDetailApi
|
||||
from .list import ResourceListApi
|
||||
from .supported import ResourceTypeListApi
|
||||
|
||||
__all__ = [
|
||||
'ResourceListApi',
|
||||
'ResourceDetailApi',
|
||||
'ResourceTypeListApi',
|
||||
]
|
||||
@@ -1,57 +0,0 @@
|
||||
list_params = [
|
||||
{
|
||||
"name": "search",
|
||||
"in": "query",
|
||||
"description": "A search term.",
|
||||
"required": False,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "order",
|
||||
"in": "query",
|
||||
"description": "Which field to use when ordering the results.",
|
||||
"required": False,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"description": "Number of results to return per page. Default is 10.",
|
||||
"required": False,
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"name": "offset",
|
||||
"in": "query",
|
||||
"description": "The initial index from which to return the results.",
|
||||
"required": False,
|
||||
"type": "integer"
|
||||
},
|
||||
|
||||
]
|
||||
|
||||
common_params = [
|
||||
{
|
||||
"name": "resource",
|
||||
"in": "path",
|
||||
"description": """Resource to query, e.g. users, assets, permissions, acls, user-groups, policies, nodes, hosts,
|
||||
devices, clouds, webs, databases,
|
||||
gpts, ds, customs, platforms, zones, gateways, protocol-settings, labels, virtual-accounts,
|
||||
gathered-accounts, account-templates, account-template-secrets, account-backups, account-backup-executions,
|
||||
change-secret-automations, change-secret-executions, change-secret-records, gather-account-automations,
|
||||
gather-account-executions, push-account-automations, push-account-executions, push-account-records,
|
||||
check-account-automations, check-account-executions, account-risks, integration-apps, asset-permissions,
|
||||
zones, gateways, virtual-accounts, gathered-accounts, account-templates, account-template-secrets,,
|
||||
GET /api/v1/resources/ to get full supported resource.
|
||||
""",
|
||||
"required": True,
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"name": "X-JMS-ORG",
|
||||
"in": "header",
|
||||
"description": "The organization ID to use for the request. Organization is the namespace for resources, if not set, use default org",
|
||||
"required": False,
|
||||
"type": "string"
|
||||
}
|
||||
]
|
||||
@@ -1,75 +0,0 @@
|
||||
# views.py
|
||||
|
||||
from drf_spectacular.utils import extend_schema, OpenApiParameter
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from .const import common_params
|
||||
from .proxy import ProxyMixin
|
||||
from .utils import param_dic_to_param
|
||||
|
||||
one_param = [
|
||||
{
|
||||
'name': 'id',
|
||||
'in': 'path',
|
||||
'required': True,
|
||||
'description': 'Resource ID',
|
||||
'type': 'string',
|
||||
}
|
||||
]
|
||||
|
||||
object_params = [
|
||||
param_dic_to_param(d)
|
||||
for d in common_params + one_param
|
||||
]
|
||||
|
||||
|
||||
class ResourceDetailApi(ProxyMixin, APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
@extend_schema(
|
||||
operation_id="get_resource_detail",
|
||||
summary="Get resource detail",
|
||||
parameters=object_params,
|
||||
description="""
|
||||
Get resource detail.
|
||||
{resource} is the resource name, GET /api/v1/resources/ to get full supported resource.
|
||||
""",
|
||||
)
|
||||
def get(self, request, resource, pk=None):
|
||||
return self._proxy(request, resource, pk=pk, action='retrieve')
|
||||
|
||||
@extend_schema(
|
||||
operation_id="delete_resource",
|
||||
summary="Delete the resource",
|
||||
parameters=object_params,
|
||||
description="Delete the resource, and can not be restored",
|
||||
)
|
||||
def delete(self, request, resource, pk=None):
|
||||
return self._proxy(request, resource, pk, action='destroy')
|
||||
|
||||
@extend_schema(
|
||||
operation_id="update_resource",
|
||||
summary="Update the resource property",
|
||||
parameters=object_params,
|
||||
description="""
|
||||
Update the resource property, all property will be update,
|
||||
{resource} is the resource name, GET /api/v1/resources/ to get full supported resource.
|
||||
|
||||
OPTION /api/v1/resources/{resource}/{id}/?action=put to get field type and helptext.
|
||||
""",
|
||||
)
|
||||
def put(self, request, resource, pk=None):
|
||||
return self._proxy(request, resource, pk, action='update')
|
||||
|
||||
@extend_schema(
|
||||
operation_id="partial_update_resource",
|
||||
summary="Update the resource property",
|
||||
parameters=object_params,
|
||||
description="""
|
||||
Partial update the resource property, only request property will be update,
|
||||
OPTION /api/v1/resources/{resource}/{id}/?action=patch to get field type and helptext.
|
||||
""",
|
||||
)
|
||||
def patch(self, request, resource, pk=None):
|
||||
return self._proxy(request, resource, pk, action='partial_update')
|
||||
@@ -1,87 +0,0 @@
|
||||
# views.py
|
||||
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from .const import list_params, common_params
|
||||
from .proxy import ProxyMixin
|
||||
from .utils import param_dic_to_param
|
||||
|
||||
router = DefaultRouter()
|
||||
|
||||
BASE_URL = "http://localhost:8080"
|
||||
|
||||
list_params = [
|
||||
param_dic_to_param(d)
|
||||
for d in list_params + common_params
|
||||
]
|
||||
|
||||
create_params = [
|
||||
param_dic_to_param(d)
|
||||
for d in common_params
|
||||
]
|
||||
|
||||
list_schema = {
|
||||
"required": [
|
||||
"count",
|
||||
"results"
|
||||
],
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"next": {
|
||||
"type": "string",
|
||||
"format": "uri",
|
||||
"x-nullable": True
|
||||
},
|
||||
"previous": {
|
||||
"type": "string",
|
||||
"format": "uri",
|
||||
"x-nullable": True
|
||||
},
|
||||
"results": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
from drf_spectacular.openapi import OpenApiResponse, OpenApiExample
|
||||
|
||||
|
||||
class ResourceListApi(ProxyMixin, APIView):
|
||||
@extend_schema(
|
||||
operation_id="get_resource_list",
|
||||
summary="Get resource list",
|
||||
parameters=list_params,
|
||||
responses={200: OpenApiResponse(description="Resource list response")},
|
||||
description="""
|
||||
Get resource list, you should set the resource name in the url.
|
||||
OPTIONS /api/v1/resources/{resource}/?action=get to get every type resource's field type and help text.
|
||||
""",
|
||||
)
|
||||
# ↓↓↓ Swagger 自动文档 ↓↓↓
|
||||
def get(self, request, resource):
|
||||
return self._proxy(request, resource)
|
||||
|
||||
@extend_schema(
|
||||
operation_id="create_resource_by_type",
|
||||
summary="Create resource",
|
||||
parameters=create_params,
|
||||
description="""
|
||||
Create resource,
|
||||
OPTIONS /api/v1/resources/{resource}/?action=post to get every resource type field type and helptext, and
|
||||
you will know how to create it.
|
||||
""",
|
||||
)
|
||||
def post(self, request, resource, pk=None):
|
||||
if not resource:
|
||||
resource = request.data.pop('resource', '')
|
||||
return self._proxy(request, resource, pk, action='create')
|
||||
|
||||
def options(self, request, resource, pk=None):
|
||||
return self._proxy(request, resource, pk, action='metadata')
|
||||
@@ -1,75 +0,0 @@
|
||||
# views.py
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
from rest_framework.exceptions import NotFound, APIException
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from .utils import get_full_resource_map
|
||||
|
||||
router = DefaultRouter()
|
||||
|
||||
BASE_URL = "http://localhost:8080"
|
||||
|
||||
|
||||
class ProxyMixin(APIView):
|
||||
"""
|
||||
通用资源代理 API,支持动态路径、自动文档生成
|
||||
"""
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def _build_url(self, resource_name: str, pk: str = None, query_params=None):
|
||||
resource_map = get_full_resource_map()
|
||||
resource = resource_map.get(resource_name)
|
||||
if not resource:
|
||||
raise NotFound(f"Unknown resource: {resource_name}")
|
||||
|
||||
base_path = resource['path']
|
||||
if pk:
|
||||
base_path += f"{pk}/"
|
||||
|
||||
if query_params:
|
||||
base_path += f"?{urlencode(query_params)}"
|
||||
|
||||
return f"{BASE_URL}{base_path}"
|
||||
|
||||
def _proxy(self, request, resource: str, pk: str = None, action='list'):
|
||||
method = request.method.lower()
|
||||
if method not in ['get', 'post', 'put', 'patch', 'delete', 'options']:
|
||||
raise APIException("Unsupported method")
|
||||
|
||||
if not resource or resource == '{resource}':
|
||||
if request.data:
|
||||
resource = request.data.get('resource')
|
||||
|
||||
query_params = request.query_params.dict()
|
||||
if action == 'list':
|
||||
query_params['limit'] = 10
|
||||
|
||||
url = self._build_url(resource, pk, query_params)
|
||||
headers = {k: v for k, v in request.headers.items() if k.lower() != 'host'}
|
||||
cookies = request.COOKIES
|
||||
body = request.body if method in ['post', 'put', 'patch'] else None
|
||||
|
||||
try:
|
||||
resp = requests.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
data=body,
|
||||
timeout=10,
|
||||
)
|
||||
content_type = resp.headers.get('Content-Type', '')
|
||||
if 'application/json' in content_type:
|
||||
data = resp.json()
|
||||
else:
|
||||
data = resp.text # 或者 bytes:resp.content
|
||||
|
||||
return Response(data=data, status=resp.status_code)
|
||||
except requests.RequestException as e:
|
||||
raise APIException(f"Proxy request failed: {str(e)}")
|
||||
@@ -1,45 +0,0 @@
|
||||
# views.py
|
||||
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework import serializers
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from .utils import get_full_resource_map
|
||||
|
||||
router = DefaultRouter()
|
||||
|
||||
BASE_URL = "http://localhost:8080"
|
||||
|
||||
|
||||
class ResourceTypeResourceSerializer(serializers.Serializer):
|
||||
name = serializers.CharField()
|
||||
path = serializers.CharField()
|
||||
app = serializers.CharField()
|
||||
verbose_name = serializers.CharField()
|
||||
description = serializers.CharField()
|
||||
|
||||
|
||||
class ResourceTypeListApi(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
@extend_schema(
|
||||
operation_id="get_supported_resources",
|
||||
summary="Get-all-support-resources",
|
||||
description="Get all support resources, name, path, verbose_name description",
|
||||
responses={200: ResourceTypeResourceSerializer(many=True)}, # Specify the response serializer
|
||||
)
|
||||
def get(self, request):
|
||||
result = []
|
||||
resource_map = get_full_resource_map()
|
||||
for name, desc in resource_map.items():
|
||||
desc = resource_map.get(name, {})
|
||||
resource = {
|
||||
"name": name,
|
||||
**desc,
|
||||
"path": f'/api/v1/resources/{name}/',
|
||||
}
|
||||
result.append(resource)
|
||||
return Response(result)
|
||||
@@ -1,128 +0,0 @@
|
||||
# views.py
|
||||
|
||||
import re
|
||||
from functools import lru_cache
|
||||
from typing import Dict
|
||||
|
||||
from django.urls import URLPattern
|
||||
from django.urls import URLResolver
|
||||
from drf_spectacular.utils import OpenApiParameter
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
router = DefaultRouter()
|
||||
|
||||
BASE_URL = "http://localhost:8080"
|
||||
|
||||
|
||||
def clean_path(path: str) -> str:
|
||||
"""
|
||||
清理掉 DRF 自动生成的正则格式内容,让其变成普通 RESTful URL path。
|
||||
"""
|
||||
|
||||
# 去掉格式后缀匹配: \.(?P<format>xxx)
|
||||
path = re.sub(r'\\\.\(\?P<format>[^)]+\)', '', path)
|
||||
|
||||
# 去掉括号式格式匹配
|
||||
path = re.sub(r'\(\?P<format>[^)]+\)', '', path)
|
||||
|
||||
# 移除 DRF 中正则参数的部分 (?P<param>pattern)
|
||||
path = re.sub(r'\(\?P<\w+>[^)]+\)', '{param}', path)
|
||||
|
||||
# 如果有多个括号包裹的正则(比如前缀路径),去掉可选部分包装
|
||||
path = re.sub(r'\(\(([^)]+)\)\?\)', r'\1', path) # ((...))? => ...
|
||||
|
||||
# 去掉中间和两边的 ^ 和 $
|
||||
path = path.replace('^', '').replace('$', '')
|
||||
|
||||
# 去掉尾部 ?/
|
||||
path = re.sub(r'\?/?$', '', path)
|
||||
|
||||
# 去掉反斜杠
|
||||
path = path.replace('\\', '')
|
||||
|
||||
# 替换多重斜杠
|
||||
path = re.sub(r'/+', '/', path)
|
||||
|
||||
# 添加开头斜杠,移除多余空格
|
||||
path = path.strip()
|
||||
if not path.startswith('/'):
|
||||
path = '/' + path
|
||||
if not path.endswith('/'):
|
||||
path += '/'
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def extract_resource_paths(urlpatterns, prefix='/api/v1/') -> Dict[str, Dict[str, str]]:
|
||||
resource_map = {}
|
||||
|
||||
for pattern in urlpatterns:
|
||||
if isinstance(pattern, URLResolver):
|
||||
nested_prefix = prefix + str(pattern.pattern)
|
||||
resource_map.update(extract_resource_paths(pattern.url_patterns, nested_prefix))
|
||||
|
||||
elif isinstance(pattern, URLPattern):
|
||||
callback = pattern.callback
|
||||
actions = getattr(callback, 'actions', {})
|
||||
if not actions:
|
||||
continue
|
||||
|
||||
if 'get' in actions and actions['get'] == 'list':
|
||||
path = clean_path(prefix + str(pattern.pattern))
|
||||
|
||||
# 尝试获取资源名称
|
||||
name = pattern.name
|
||||
if name and name.endswith('-list'):
|
||||
resource = name[:-5]
|
||||
else:
|
||||
resource = path.strip('/').split('/')[-1]
|
||||
|
||||
# 不强行加 s,资源名保持原状即可
|
||||
resource = resource if resource.endswith('s') else resource + 's'
|
||||
|
||||
# 获取 View 类和 model 的 verbose_name
|
||||
view_cls = getattr(callback, 'cls', None)
|
||||
model = None
|
||||
|
||||
if view_cls:
|
||||
queryset = getattr(view_cls, 'queryset', None)
|
||||
if queryset is not None:
|
||||
model = getattr(queryset, 'model', None)
|
||||
else:
|
||||
# 有些 View 用 get_queryset()
|
||||
try:
|
||||
instance = view_cls()
|
||||
qs = instance.get_queryset()
|
||||
model = getattr(qs, 'model', None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not model:
|
||||
continue
|
||||
|
||||
app = str(getattr(model._meta, 'app_label', ''))
|
||||
verbose_name = str(getattr(model._meta, 'verbose_name', ''))
|
||||
resource_map[resource] = {
|
||||
'path': path,
|
||||
'app': app,
|
||||
'verbose_name': verbose_name,
|
||||
'description': model.__doc__.__str__()
|
||||
}
|
||||
|
||||
print("Extracted resource paths:", list(resource_map.keys()))
|
||||
return resource_map
|
||||
|
||||
|
||||
def param_dic_to_param(d):
|
||||
return OpenApiParameter(
|
||||
name=d['name'], location=d['in'],
|
||||
description=d['description'], type=d['type'], required=d.get('required', False)
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_full_resource_map():
|
||||
from apps.jumpserver.urls import resource_api
|
||||
resource_map = extract_resource_paths(resource_api)
|
||||
print("Building URL for resource:", resource_map)
|
||||
return resource_map
|
||||
@@ -701,15 +701,7 @@ class Config(dict):
|
||||
'CHAT_AI_ENABLED': False,
|
||||
'CHAT_AI_METHOD': 'api',
|
||||
'CHAT_AI_EMBED_URL': '',
|
||||
'CHAT_AI_TYPE': 'gpt',
|
||||
'GPT_BASE_URL': '',
|
||||
'GPT_API_KEY': '',
|
||||
'GPT_PROXY': '',
|
||||
'GPT_MODEL': 'gpt-4o-mini',
|
||||
'DEEPSEEK_BASE_URL': '',
|
||||
'DEEPSEEK_API_KEY': '',
|
||||
'DEEPSEEK_PROXY': '',
|
||||
'DEEPSEEK_MODEL': 'deepseek-chat',
|
||||
'CHAT_AI_PROVIDERS': [],
|
||||
'VIRTUAL_APP_ENABLED': False,
|
||||
|
||||
'FILE_UPLOAD_SIZE_LIMIT_MB': 200,
|
||||
|
||||
@@ -5,6 +5,7 @@ from rest_framework.pagination import LimitOffsetPagination
|
||||
|
||||
class MaxLimitOffsetPagination(LimitOffsetPagination):
|
||||
max_limit = settings.MAX_PAGE_SIZE
|
||||
default_limit = settings.DEFAULT_PAGE_SIZE
|
||||
|
||||
def get_count(self, queryset):
|
||||
try:
|
||||
|
||||
@@ -241,15 +241,7 @@ ASSET_SIZE = 'small'
|
||||
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
|
||||
CHAT_AI_METHOD = CONFIG.CHAT_AI_METHOD
|
||||
CHAT_AI_EMBED_URL = CONFIG.CHAT_AI_EMBED_URL
|
||||
CHAT_AI_TYPE = CONFIG.CHAT_AI_TYPE
|
||||
GPT_BASE_URL = CONFIG.GPT_BASE_URL
|
||||
GPT_API_KEY = CONFIG.GPT_API_KEY
|
||||
GPT_PROXY = CONFIG.GPT_PROXY
|
||||
GPT_MODEL = CONFIG.GPT_MODEL
|
||||
DEEPSEEK_BASE_URL = CONFIG.DEEPSEEK_BASE_URL
|
||||
DEEPSEEK_API_KEY = CONFIG.DEEPSEEK_API_KEY
|
||||
DEEPSEEK_PROXY = CONFIG.DEEPSEEK_PROXY
|
||||
DEEPSEEK_MODEL = CONFIG.DEEPSEEK_MODEL
|
||||
CHAT_AI_DEFAULT_PROVIDER = CONFIG.CHAT_AI_DEFAULT_PROVIDER
|
||||
|
||||
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
|
||||
|
||||
@@ -268,4 +260,6 @@ LOKI_BASE_URL = CONFIG.LOKI_BASE_URL
|
||||
TOOL_USER_ENABLED = CONFIG.TOOL_USER_ENABLED
|
||||
|
||||
SUGGESTION_LIMIT = CONFIG.SUGGESTION_LIMIT
|
||||
MCP_ENABLED = CONFIG.MCP_ENABLED
|
||||
MCP_ENABLED = CONFIG.MCP_ENABLED
|
||||
CHAT_AI_PROVIDERS = CONFIG.CHAT_AI_PROVIDERS
|
||||
|
||||
|
||||
@@ -85,6 +85,7 @@ SPECTACULAR_SETTINGS = {
|
||||
'jumpserver.views.schema.LabeledChoiceFieldExtension',
|
||||
'jumpserver.views.schema.BitChoicesFieldExtension',
|
||||
'jumpserver.views.schema.LabelRelatedFieldExtension',
|
||||
'jumpserver.views.schema.DateTimeFieldExtension',
|
||||
],
|
||||
'SECURITY': [{'Bearer': []}],
|
||||
}
|
||||
|
||||
@@ -37,12 +37,6 @@ api_v1 = resource_api + [
|
||||
path('prometheus/metrics/', api.PrometheusMetricsApi.as_view()),
|
||||
path('search/', api.GlobalSearchView.as_view()),
|
||||
]
|
||||
if settings.MCP_ENABLED:
|
||||
api_v1.extend([
|
||||
path('resources/', api.ResourceTypeListApi.as_view(), name='resource-list'),
|
||||
path('resources/<str:resource>/', api.ResourceListApi.as_view()),
|
||||
path('resources/<str:resource>/<str:pk>/', api.ResourceDetailApi.as_view()),
|
||||
])
|
||||
|
||||
app_view_patterns = [
|
||||
path('auth/', include('authentication.urls.view_urls'), name='auth'),
|
||||
|
||||
@@ -1,421 +0,0 @@
|
||||
import re
|
||||
|
||||
from drf_spectacular.openapi import AutoSchema
|
||||
from drf_spectacular.generators import SchemaGenerator
|
||||
|
||||
|
||||
class CustomSchemaGenerator(SchemaGenerator):
|
||||
from_mcp = False
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
self.from_mcp = request.query_params.get('mcp') or request.path.endswith('swagger.json')
|
||||
return super().get_schema(request, public)
|
||||
|
||||
|
||||
class CustomAutoSchema(AutoSchema):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.from_mcp = kwargs.get('from_mcp', False)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def map_parsers(self):
|
||||
return ['application/json']
|
||||
|
||||
def map_renderers(self, *args, **kwargs):
|
||||
return ['application/json']
|
||||
|
||||
def get_tags(self):
|
||||
operation_keys = self._tokenize_path()
|
||||
if len(operation_keys) == 1:
|
||||
return []
|
||||
tags = ['_'.join(operation_keys[:2])]
|
||||
return tags
|
||||
|
||||
def get_operation(self, path, *args, **kwargs):
|
||||
if path.endswith('render-to-json/'):
|
||||
return None
|
||||
# if not path.startswith('/api/v1/users'):
|
||||
# return None
|
||||
operation = super().get_operation(path, *args, **kwargs)
|
||||
if not operation:
|
||||
return operation
|
||||
|
||||
if not operation.get('summary', ''):
|
||||
operation['summary'] = operation.get('operationId')
|
||||
|
||||
return operation
|
||||
|
||||
def get_operation_id(self):
|
||||
tokenized_path = self._tokenize_path()
|
||||
# replace dashes as they can be problematic later in code generation
|
||||
tokenized_path = [t.replace('-', '_') for t in tokenized_path]
|
||||
|
||||
action = ''
|
||||
if hasattr(self.view, 'action'):
|
||||
action = self.view.action
|
||||
|
||||
if not action:
|
||||
if self.method == 'GET' and self._is_list_view():
|
||||
action = 'list'
|
||||
else:
|
||||
action = self.method_mapping[self.method.lower()]
|
||||
|
||||
if action == "bulk_destroy":
|
||||
action = "bulk_delete"
|
||||
|
||||
if not tokenized_path:
|
||||
tokenized_path.append('root')
|
||||
|
||||
if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
|
||||
tokenized_path.append('formatted')
|
||||
|
||||
return '_'.join(tokenized_path + [action])
|
||||
|
||||
def get_filter_parameters(self):
|
||||
if not self.should_filter():
|
||||
return []
|
||||
|
||||
fields = []
|
||||
if hasattr(self.view, 'get_filter_backends'):
|
||||
backends = self.view.get_filter_backends()
|
||||
elif hasattr(self.view, 'filter_backends'):
|
||||
backends = self.view.filter_backends
|
||||
else:
|
||||
backends = []
|
||||
for filter_backend in backends:
|
||||
fields += self.probe_inspectors(
|
||||
self.filter_inspectors, 'get_filter_parameters', filter_backend()
|
||||
) or []
|
||||
return fields
|
||||
|
||||
def get_auth(self):
|
||||
return [{'Bearer': []}]
|
||||
|
||||
def get_operation_security(self):
|
||||
"""
|
||||
重写操作安全配置,统一使用 Bearer token
|
||||
"""
|
||||
return [{'Bearer': []}]
|
||||
|
||||
def get_components_security_schemes(self):
|
||||
"""
|
||||
重写安全方案定义,避免认证类解析错误
|
||||
"""
|
||||
return {
|
||||
'Bearer': {
|
||||
'type': 'http',
|
||||
'scheme': 'bearer',
|
||||
'bearerFormat': 'JWT',
|
||||
'description': 'JWT token for API authentication'
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def exclude_some_paths(path):
|
||||
# 这里可以对 paths 进行处理
|
||||
excludes = [
|
||||
'/report/', '/render-to-json/', '/suggestions/',
|
||||
'executions', 'automations', 'change-secret-records',
|
||||
'change-secret-dashboard', '/copy-to-assets/',
|
||||
'/move-to-assets/', 'dashboard', 'index', 'countries',
|
||||
'/resources/cache/', 'profile/mfa', 'profile/password',
|
||||
'profile/permissions', 'prometheus', 'constraints'
|
||||
]
|
||||
for p in excludes:
|
||||
if path.find(p) >= 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def exclude_some_app_model(self, path):
|
||||
parts = path.split('/')
|
||||
if len(parts) < 5:
|
||||
return False
|
||||
|
||||
apps = []
|
||||
if self.from_mcp:
|
||||
apps = [
|
||||
'ops', 'tickets', 'authentication',
|
||||
'settings', 'xpack', 'terminal', 'rbac',
|
||||
'notifications', 'promethues', 'acls'
|
||||
]
|
||||
|
||||
app_name = parts[3]
|
||||
if app_name in apps:
|
||||
return True
|
||||
models = []
|
||||
model = parts[4]
|
||||
if self.from_mcp:
|
||||
models = [
|
||||
'users', 'user-groups', 'users-groups-relations', 'assets', 'hosts', 'devices', 'databases',
|
||||
'webs', 'clouds', 'gpts', 'ds', 'customs', 'platforms', 'nodes', 'zones', 'gateways',
|
||||
'protocol-settings', 'labels', 'virtual-accounts', 'gathered-accounts', 'account-templates',
|
||||
'account-template-secrets', 'account-backups', 'account-backup-executions',
|
||||
'change-secret-automations', 'change-secret-executions', 'change-secret-records',
|
||||
'gather-account-automations', 'gather-account-executions', 'push-account-automations',
|
||||
'push-account-executions', 'push-account-records', 'check-account-automations',
|
||||
'check-account-executions', 'account-risks', 'integration-apps', 'asset-permissions',
|
||||
'asset-permissions-users-relations', 'asset-permissions-user-groups-relations',
|
||||
'asset-permissions-assets-relations', 'asset-permissions-nodes-relations', 'terminal-status',
|
||||
'terminals', 'tasks', 'status', 'replay-storages', 'command-storages', 'session-sharing-records',
|
||||
'endpoints', 'endpoint-rules', 'applets', 'applet-hosts', 'applet-publications',
|
||||
'applet-host-deployments', 'virtual-apps', 'app-providers', 'virtual-app-publications',
|
||||
'celery-period-tasks', 'task-executions', 'adhocs', 'playbooks', 'variables', 'ftp-logs',
|
||||
'login-logs', 'operate-logs', 'password-change-logs', 'job-logs', 'jobs', 'user-sessions',
|
||||
'service-access-logs', 'chatai-prompts', 'super-connection-tokens', 'flows',
|
||||
'apply-assets', 'apply-nodes', 'login-acls', 'login-asset-acls', 'command-filter-acls',
|
||||
'command-groups', 'connect-method-acls', 'system-msg-subscriptions', 'roles', 'role-bindings',
|
||||
'system-roles', 'system-role-bindings', 'org-roles', 'org-role-bindings', 'content-types',
|
||||
'labeled-resources', 'account-backup-plans', 'account-check-engines', 'account-secrets',
|
||||
'change-secret', 'integration-applications', 'push-account', 'directories', 'connection-token',
|
||||
'groups', 'accounts', 'resource-types', 'favorite-assets', 'activities', 'platform-automation-methods',
|
||||
]
|
||||
if model in models:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_excluded(self):
|
||||
if self.exclude_some_paths(self.path):
|
||||
return True
|
||||
if self.exclude_some_app_model(self.path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_operation(self, path, *args, **kwargs):
|
||||
operation = super().get_operation(path, *args, **kwargs)
|
||||
if not operation:
|
||||
return operation
|
||||
|
||||
operation_id = operation.get('operationId')
|
||||
if 'bulk' in operation_id:
|
||||
return None
|
||||
|
||||
if not operation.get('summary', ''):
|
||||
operation['summary'] = operation.get('operationId')
|
||||
|
||||
exclude_operations = [
|
||||
'orgs_orgs_read', 'orgs_orgs_update', 'orgs_orgs_delete',
|
||||
'orgs_orgs_create', 'orgs_orgs_partial_update',
|
||||
]
|
||||
if operation_id in exclude_operations:
|
||||
return None
|
||||
return operation
|
||||
|
||||
# 添加自定义字段的 OpenAPI 扩展
|
||||
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
|
||||
from drf_spectacular.openapi import AutoSchema
|
||||
from drf_spectacular.plumbing import build_basic_type
|
||||
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField, BitChoicesField
|
||||
|
||||
|
||||
class ObjectRelatedFieldExtension(OpenApiSerializerFieldExtension):
|
||||
"""
|
||||
为 ObjectRelatedField 提供 OpenAPI schema
|
||||
"""
|
||||
target_class = ObjectRelatedField
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
field = self.target
|
||||
|
||||
# 获取字段的基本信息
|
||||
field_type = 'array' if field.many else 'object'
|
||||
|
||||
if field_type == 'array':
|
||||
# 如果是多对多关系
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self._get_openapi_item_schema(field),
|
||||
'description': getattr(field, 'help_text', ''),
|
||||
'title': getattr(field, 'label', ''),
|
||||
}
|
||||
else:
|
||||
# 如果是一对一关系
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': self._get_openapi_properties_schema(field),
|
||||
'description': getattr(field, 'help_text', ''),
|
||||
'title': getattr(field, 'label', ''),
|
||||
}
|
||||
|
||||
def _get_openapi_item_schema(self, field):
|
||||
"""
|
||||
获取数组项的 OpenAPI schema
|
||||
"""
|
||||
return self._get_openapi_object_schema(field)
|
||||
|
||||
def _get_openapi_object_schema(self, field):
|
||||
"""
|
||||
获取对象的 OpenAPI schema
|
||||
"""
|
||||
properties = {}
|
||||
|
||||
# 动态分析 attrs 中的属性类型
|
||||
for attr in field.attrs:
|
||||
# 尝试从 queryset 的 model 中获取字段信息
|
||||
field_type = self._infer_field_type(field, attr)
|
||||
properties[attr] = {
|
||||
'type': field_type,
|
||||
'description': f'{attr} field'
|
||||
}
|
||||
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': properties,
|
||||
'required': ['id'] if 'id' in field.attrs else []
|
||||
}
|
||||
|
||||
def _infer_field_type(self, field, attr_name):
|
||||
"""
|
||||
智能推断字段类型
|
||||
"""
|
||||
try:
|
||||
# 如果有 queryset,尝试从 model 中获取字段信息
|
||||
if hasattr(field, 'queryset') and field.queryset is not None:
|
||||
model = field.queryset.model
|
||||
if hasattr(model, '_meta') and hasattr(model._meta, 'fields'):
|
||||
model_field = model._meta.get_field(attr_name)
|
||||
if model_field:
|
||||
return self._map_django_field_type(model_field)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 如果没有 queryset 或无法获取字段信息,使用启发式规则
|
||||
return self._heuristic_field_type(attr_name)
|
||||
|
||||
def _map_django_field_type(self, model_field):
|
||||
"""
|
||||
将 Django 字段类型映射到 OpenAPI 类型
|
||||
"""
|
||||
field_type = type(model_field).__name__
|
||||
|
||||
# 整数类型
|
||||
if 'Integer' in field_type or 'BigInteger' in field_type or 'SmallInteger' in field_type:
|
||||
return 'integer'
|
||||
# 浮点数类型
|
||||
elif 'Float' in field_type or 'Decimal' in field_type:
|
||||
return 'number'
|
||||
# 布尔类型
|
||||
elif 'Boolean' in field_type:
|
||||
return 'boolean'
|
||||
# 日期时间类型
|
||||
elif 'DateTime' in field_type or 'Date' in field_type or 'Time' in field_type:
|
||||
return 'string'
|
||||
# 文件类型
|
||||
elif 'File' in field_type or 'Image' in field_type:
|
||||
return 'string'
|
||||
# 其他类型默认为字符串
|
||||
else:
|
||||
return 'string'
|
||||
|
||||
def _heuristic_field_type(self, attr_name):
|
||||
"""
|
||||
启发式推断字段类型
|
||||
"""
|
||||
# 基于属性名的启发式规则
|
||||
|
||||
if attr_name in ['is_active', 'enabled', 'visible'] or attr_name.startswith('is_'):
|
||||
return 'boolean'
|
||||
elif attr_name in ['count', 'number', 'size', 'amount']:
|
||||
return 'integer'
|
||||
elif attr_name in ['price', 'rate', 'percentage']:
|
||||
return 'number'
|
||||
else:
|
||||
# 默认返回字符串类型
|
||||
return 'string'
|
||||
|
||||
def _get_openapi_properties_schema(self, field):
|
||||
"""
|
||||
获取对象属性的 OpenAPI schema
|
||||
"""
|
||||
return self._get_openapi_object_schema(field)['properties']
|
||||
|
||||
|
||||
class LabeledChoiceFieldExtension(OpenApiSerializerFieldExtension):
|
||||
"""
|
||||
为 LabeledChoiceField 提供 OpenAPI schema
|
||||
"""
|
||||
target_class = LabeledChoiceField
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
field = self.target
|
||||
|
||||
if getattr(field, 'many', False):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'value': {'type': 'string'},
|
||||
'label': {'type': 'string'}
|
||||
}
|
||||
},
|
||||
'description': getattr(field, 'help_text', ''),
|
||||
'title': getattr(field, 'label', ''),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'value': {'type': 'string'},
|
||||
'label': {'type': 'string'}
|
||||
},
|
||||
'description': getattr(field, 'help_text', ''),
|
||||
'title': getattr(field, 'label', ''),
|
||||
}
|
||||
|
||||
|
||||
class BitChoicesFieldExtension(OpenApiSerializerFieldExtension):
|
||||
"""
|
||||
为 BitChoicesField 提供 OpenAPI schema
|
||||
"""
|
||||
target_class = BitChoicesField
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
field = self.target
|
||||
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'value': {'type': 'string'},
|
||||
'label': {'type': 'string'}
|
||||
}
|
||||
},
|
||||
'description': getattr(field, 'help_text', ''),
|
||||
'title': getattr(field, 'label', ''),
|
||||
}
|
||||
|
||||
|
||||
class LabelRelatedFieldExtension(OpenApiSerializerFieldExtension):
|
||||
"""
|
||||
为 LabelRelatedField 提供 OpenAPI schema
|
||||
"""
|
||||
target_class = 'common.serializers.fields.LabelRelatedField'
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
field = self.target
|
||||
|
||||
# LabelRelatedField 返回一个包含 id, name, value, color 的对象
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'id': {
|
||||
'type': 'string',
|
||||
'description': 'Label ID'
|
||||
},
|
||||
'name': {
|
||||
'type': 'string',
|
||||
'description': 'Label name'
|
||||
},
|
||||
'value': {
|
||||
'type': 'string',
|
||||
'description': 'Label value'
|
||||
},
|
||||
'color': {
|
||||
'type': 'string',
|
||||
'description': 'Label color'
|
||||
}
|
||||
},
|
||||
'required': ['id', 'name', 'value'],
|
||||
'description': getattr(field, 'help_text', 'Label information'),
|
||||
'title': getattr(field, 'label', 'Label'),
|
||||
}
|
||||
2
apps/jumpserver/views/schema/__init__.py
Normal file
2
apps/jumpserver/views/schema/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .extension import *
|
||||
from .schema import *
|
||||
263
apps/jumpserver/views/schema/extension.py
Normal file
263
apps/jumpserver/views/schema/extension.py
Normal file
@@ -0,0 +1,263 @@
|
||||
|
||||
# 添加自定义字段的 OpenAPI 扩展
|
||||
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
|
||||
from drf_spectacular.openapi import AutoSchema
|
||||
from drf_spectacular.plumbing import build_basic_type
|
||||
from rest_framework import serializers
|
||||
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField, BitChoicesField
|
||||
|
||||
|
||||
__all__ = [
|
||||
'ObjectRelatedFieldExtension', 'LabeledChoiceFieldExtension',
|
||||
'BitChoicesFieldExtension', 'LabelRelatedFieldExtension',
|
||||
'DateTimeFieldExtension'
|
||||
]
|
||||
|
||||
|
||||
class ObjectRelatedFieldExtension(OpenApiSerializerFieldExtension):
|
||||
"""
|
||||
为 ObjectRelatedField 提供 OpenAPI schema
|
||||
"""
|
||||
target_class = ObjectRelatedField
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
field = self.target
|
||||
|
||||
# 获取字段的基本信息
|
||||
field_type = 'array' if field.many else 'object'
|
||||
|
||||
if field_type == 'array':
|
||||
# 如果是多对多关系
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': self._get_openapi_item_schema(field),
|
||||
'description': getattr(field, 'help_text', ''),
|
||||
'title': getattr(field, 'label', ''),
|
||||
}
|
||||
else:
|
||||
# 如果是一对一关系
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': self._get_openapi_properties_schema(field),
|
||||
'description': getattr(field, 'help_text', ''),
|
||||
'title': getattr(field, 'label', ''),
|
||||
}
|
||||
|
||||
def _get_openapi_item_schema(self, field):
|
||||
"""
|
||||
获取数组项的 OpenAPI schema
|
||||
"""
|
||||
return self._get_openapi_object_schema(field)
|
||||
|
||||
def _get_openapi_object_schema(self, field):
|
||||
"""
|
||||
获取对象的 OpenAPI schema
|
||||
"""
|
||||
properties = {}
|
||||
|
||||
# 动态分析 attrs 中的属性类型
|
||||
for attr in field.attrs:
|
||||
# 尝试从 queryset 的 model 中获取字段信息
|
||||
field_type = self._infer_field_type(field, attr)
|
||||
properties[attr] = {
|
||||
'type': field_type,
|
||||
'description': f'{attr} field'
|
||||
}
|
||||
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': properties,
|
||||
'required': ['id'] if 'id' in field.attrs else []
|
||||
}
|
||||
|
||||
def _infer_field_type(self, field, attr_name):
|
||||
"""
|
||||
智能推断字段类型
|
||||
"""
|
||||
try:
|
||||
# 如果有 queryset,尝试从 model 中获取字段信息
|
||||
if hasattr(field, 'queryset') and field.queryset is not None:
|
||||
model = field.queryset.model
|
||||
if hasattr(model, '_meta') and hasattr(model._meta, 'fields'):
|
||||
model_field = model._meta.get_field(attr_name)
|
||||
if model_field:
|
||||
return self._map_django_field_type(model_field)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 如果没有 queryset 或无法获取字段信息,使用启发式规则
|
||||
return self._heuristic_field_type(attr_name)
|
||||
|
||||
def _map_django_field_type(self, model_field):
|
||||
"""
|
||||
将 Django 字段类型映射到 OpenAPI 类型
|
||||
"""
|
||||
field_type = type(model_field).__name__
|
||||
|
||||
# 整数类型
|
||||
if 'Integer' in field_type or 'BigInteger' in field_type or 'SmallInteger' in field_type or 'AutoField' in field_type:
|
||||
return 'integer'
|
||||
# 浮点数类型
|
||||
elif 'Float' in field_type or 'Decimal' in field_type:
|
||||
return 'number'
|
||||
# 布尔类型
|
||||
elif 'Boolean' in field_type:
|
||||
return 'boolean'
|
||||
# 日期时间类型
|
||||
elif 'DateTime' in field_type or 'Date' in field_type or 'Time' in field_type:
|
||||
return 'string'
|
||||
# 文件类型
|
||||
elif 'File' in field_type or 'Image' in field_type:
|
||||
return 'string'
|
||||
# 其他类型默认为字符串
|
||||
else:
|
||||
return 'string'
|
||||
|
||||
def _heuristic_field_type(self, attr_name):
|
||||
"""
|
||||
启发式推断字段类型
|
||||
"""
|
||||
# 基于属性名的启发式规则
|
||||
|
||||
if attr_name in ['is_active', 'enabled', 'visible'] or attr_name.startswith('is_'):
|
||||
return 'boolean'
|
||||
elif attr_name in ['count', 'number', 'size', 'amount']:
|
||||
return 'integer'
|
||||
elif attr_name in ['price', 'rate', 'percentage']:
|
||||
return 'number'
|
||||
else:
|
||||
# 默认返回字符串类型
|
||||
return 'string'
|
||||
|
||||
def _get_openapi_properties_schema(self, field):
|
||||
"""
|
||||
获取对象属性的 OpenAPI schema
|
||||
"""
|
||||
return self._get_openapi_object_schema(field)['properties']
|
||||
|
||||
|
||||
class LabeledChoiceFieldExtension(OpenApiSerializerFieldExtension):
|
||||
"""
|
||||
为 LabeledChoiceField 提供 OpenAPI schema
|
||||
"""
|
||||
target_class = LabeledChoiceField
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
field = self.target
|
||||
|
||||
if getattr(field, 'many', False):
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'value': {'type': 'string'},
|
||||
'label': {'type': 'string'}
|
||||
}
|
||||
},
|
||||
'description': getattr(field, 'help_text', ''),
|
||||
'title': getattr(field, 'label', ''),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'value': {'type': 'string'},
|
||||
'label': {'type': 'string'}
|
||||
},
|
||||
'description': getattr(field, 'help_text', ''),
|
||||
'title': getattr(field, 'label', ''),
|
||||
}
|
||||
|
||||
|
||||
class BitChoicesFieldExtension(OpenApiSerializerFieldExtension):
|
||||
"""
|
||||
为 BitChoicesField 提供 OpenAPI schema
|
||||
"""
|
||||
target_class = BitChoicesField
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
field = self.target
|
||||
|
||||
return {
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'value': {'type': 'string'},
|
||||
'label': {'type': 'string'}
|
||||
}
|
||||
},
|
||||
'description': getattr(field, 'help_text', ''),
|
||||
'title': getattr(field, 'label', ''),
|
||||
}
|
||||
|
||||
|
||||
class LabelRelatedFieldExtension(OpenApiSerializerFieldExtension):
|
||||
"""
|
||||
为 LabelRelatedField 提供 OpenAPI schema
|
||||
"""
|
||||
target_class = 'common.serializers.fields.LabelRelatedField'
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
field = self.target
|
||||
|
||||
# LabelRelatedField 返回一个包含 id, name, value, color 的对象
|
||||
return {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'id': {
|
||||
'type': 'string',
|
||||
'description': 'Label ID'
|
||||
},
|
||||
'name': {
|
||||
'type': 'string',
|
||||
'description': 'Label name'
|
||||
},
|
||||
'value': {
|
||||
'type': 'string',
|
||||
'description': 'Label value'
|
||||
},
|
||||
'color': {
|
||||
'type': 'string',
|
||||
'description': 'Label color'
|
||||
}
|
||||
},
|
||||
'required': ['id', 'name', 'value'],
|
||||
'description': getattr(field, 'help_text', 'Label information'),
|
||||
'title': getattr(field, 'label', 'Label'),
|
||||
}
|
||||
|
||||
|
||||
class DateTimeFieldExtension(OpenApiSerializerFieldExtension):
|
||||
"""
|
||||
为 DateTimeField 提供自定义 OpenAPI schema
|
||||
修正 datetime 字段格式,使其符合实际返回格式 '%Y/%m/%d %H:%M:%S %z'
|
||||
而不是标准的 ISO 8601 格式 (date-time)
|
||||
"""
|
||||
target_class = serializers.DateTimeField
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
field = self.target
|
||||
|
||||
# 获取字段的描述信息,确保始终是字符串类型
|
||||
help_text = getattr(field, 'help_text', None) or ''
|
||||
description = help_text if isinstance(help_text, str) else ''
|
||||
|
||||
# 添加格式说明
|
||||
format_desc = 'Format: YYYY/MM/DD HH:MM:SS +TZ (e.g., 2023/10/01 12:00:00 +0800)'
|
||||
if description:
|
||||
description = f'{description} {format_desc}'
|
||||
else:
|
||||
description = format_desc
|
||||
|
||||
# 返回字符串类型,不包含 format: date-time
|
||||
# 因为实际返回格式是 '%Y/%m/%d %H:%M:%S %z',不是标准的 ISO 8601
|
||||
schema = {
|
||||
'type': 'string',
|
||||
'description': description,
|
||||
'title': getattr(field, 'label', '') or '',
|
||||
'example': '2023/10/01 12:00:00 +0800',
|
||||
}
|
||||
|
||||
return schema
|
||||
206
apps/jumpserver/views/schema/schema.py
Normal file
206
apps/jumpserver/views/schema/schema.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import re
|
||||
|
||||
from drf_spectacular.openapi import AutoSchema
|
||||
from drf_spectacular.generators import SchemaGenerator
|
||||
|
||||
__all__ = [
|
||||
'CustomSchemaGenerator', 'CustomAutoSchema'
|
||||
]
|
||||
|
||||
|
||||
class CustomSchemaGenerator(SchemaGenerator):
|
||||
from_mcp = False
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
self.from_mcp = request.query_params.get('mcp') or request.path.endswith('swagger.json')
|
||||
return super().get_schema(request, public)
|
||||
|
||||
|
||||
class CustomAutoSchema(AutoSchema):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.from_mcp = True
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def map_parsers(self):
|
||||
return ['application/json']
|
||||
|
||||
def map_renderers(self, *args, **kwargs):
|
||||
return ['application/json']
|
||||
|
||||
def get_tags(self):
|
||||
operation_keys = self._tokenize_path()
|
||||
if len(operation_keys) == 1:
|
||||
return []
|
||||
tags = ['_'.join(operation_keys[:2])]
|
||||
return tags
|
||||
|
||||
def get_operation_id(self):
|
||||
tokenized_path = self._tokenize_path()
|
||||
# replace dashes as they can be problematic later in code generation
|
||||
tokenized_path = [t.replace('-', '_') for t in tokenized_path]
|
||||
|
||||
action = ''
|
||||
if hasattr(self.view, 'action'):
|
||||
action = self.view.action
|
||||
|
||||
if not action:
|
||||
if self.method == 'GET' and self._is_list_view():
|
||||
action = 'list'
|
||||
else:
|
||||
action = self.method_mapping[self.method.lower()]
|
||||
|
||||
if action == "bulk_destroy":
|
||||
action = "bulk_delete"
|
||||
|
||||
if not tokenized_path:
|
||||
tokenized_path.append('root')
|
||||
|
||||
if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
|
||||
tokenized_path.append('formatted')
|
||||
|
||||
return '_'.join(tokenized_path + [action])
|
||||
|
||||
def get_filter_parameters(self):
|
||||
if not self.should_filter():
|
||||
return []
|
||||
|
||||
fields = []
|
||||
if hasattr(self.view, 'get_filter_backends'):
|
||||
backends = self.view.get_filter_backends()
|
||||
elif hasattr(self.view, 'filter_backends'):
|
||||
backends = self.view.filter_backends
|
||||
else:
|
||||
backends = []
|
||||
for filter_backend in backends:
|
||||
fields += self.probe_inspectors(
|
||||
self.filter_inspectors, 'get_filter_parameters', filter_backend()
|
||||
) or []
|
||||
return fields
|
||||
|
||||
def get_auth(self):
|
||||
return [{'Bearer': []}]
|
||||
|
||||
def get_operation_security(self):
|
||||
"""
|
||||
重写操作安全配置,统一使用 Bearer token
|
||||
"""
|
||||
return [{'Bearer': []}]
|
||||
|
||||
def get_components_security_schemes(self):
|
||||
"""
|
||||
重写安全方案定义,避免认证类解析错误
|
||||
"""
|
||||
return {
|
||||
'Bearer': {
|
||||
'type': 'http',
|
||||
'scheme': 'bearer',
|
||||
'bearerFormat': 'JWT',
|
||||
'description': 'JWT token for API authentication'
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def exclude_some_paths(path):
|
||||
# 这里可以对 paths 进行处理
|
||||
excludes = [
|
||||
'/report/', '/render-to-json/', '/suggestions/',
|
||||
'executions', 'automations', 'change-secret-records',
|
||||
'change-secret-dashboard', '/copy-to-assets/',
|
||||
'/move-to-assets/', 'dashboard', 'index', 'countries',
|
||||
'/resources/cache/', 'profile/mfa', 'profile/password',
|
||||
'profile/permissions', 'prometheus', 'constraints',
|
||||
'/api/swagger.json', '/api/swagger.yaml',
|
||||
]
|
||||
for p in excludes:
|
||||
if path.find(p) >= 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def exclude_some_models(self, model):
|
||||
models = []
|
||||
if self.from_mcp:
|
||||
models = [
|
||||
'users', 'user-groups',
|
||||
'assets', 'hosts', 'devices', 'databases',
|
||||
'webs', 'clouds', 'ds', 'platforms',
|
||||
'nodes', 'zones', 'labels',
|
||||
'accounts', 'account-templates',
|
||||
'asset-permissions',
|
||||
]
|
||||
if models and model in models:
|
||||
return False
|
||||
return True
|
||||
|
||||
def exclude_some_apps(self, app):
|
||||
apps = []
|
||||
if self.from_mcp:
|
||||
apps = [
|
||||
'users', 'assets', 'accounts',
|
||||
'perms', 'labels',
|
||||
]
|
||||
if apps and app in apps:
|
||||
return False
|
||||
return True
|
||||
|
||||
def exclude_some_app_model(self, path):
|
||||
parts = path.split('/')
|
||||
if len(parts) < 5 :
|
||||
return True
|
||||
|
||||
if len(parts) == 7 and parts[5] != "{id}":
|
||||
return True
|
||||
|
||||
if len(parts) > 7:
|
||||
return True
|
||||
|
||||
app_name = parts[3]
|
||||
if self.exclude_some_apps(app_name):
|
||||
return True
|
||||
|
||||
if self.exclude_some_models(parts[4]):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_excluded(self):
|
||||
if self.exclude_some_paths(self.path):
|
||||
return True
|
||||
|
||||
if self.exclude_some_app_model(self.path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def exclude_some_operations(self, operation_id):
|
||||
exclude_operations = [
|
||||
'orgs_orgs_read', 'orgs_orgs_update', 'orgs_orgs_delete',
|
||||
'orgs_orgs_create', 'orgs_orgs_partial_update',
|
||||
]
|
||||
if operation_id in exclude_operations:
|
||||
return True
|
||||
|
||||
if 'bulk' in operation_id:
|
||||
return True
|
||||
|
||||
if 'destroy' in operation_id:
|
||||
return True
|
||||
|
||||
if 'update' in operation_id and 'partial' not in operation_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_operation(self, path, *args, **kwargs):
|
||||
operation = super().get_operation(path, *args, **kwargs)
|
||||
if not operation:
|
||||
return operation
|
||||
|
||||
operation_id = operation.get('operationId')
|
||||
if self.exclude_some_operations(operation_id):
|
||||
return None
|
||||
|
||||
if not operation.get('summary', ''):
|
||||
operation['summary'] = operation.get('operationId')
|
||||
|
||||
# if self.is_excluded():
|
||||
# return None
|
||||
|
||||
return operation
|
||||
@@ -37,11 +37,11 @@ class SchemeMixin:
|
||||
}
|
||||
return Response(schema)
|
||||
|
||||
@method_decorator(cache_page(60 * 5,), name="dispatch")
|
||||
# @method_decorator(cache_page(60 * 5,), name="dispatch")
|
||||
class JsonApi(SchemeMixin, SpectacularJSONAPIView):
|
||||
pass
|
||||
|
||||
@method_decorator(cache_page(60 * 5,), name="dispatch")
|
||||
# @method_decorator(cache_page(60 * 5,), name="dispatch")
|
||||
class YamlApi(SchemeMixin, SpectacularYAMLAPIView):
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,98 +1,10 @@
|
||||
import httpx
|
||||
import openai
|
||||
from django.conf import settings
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import status
|
||||
from rest_framework.generics import GenericAPIView
|
||||
from rest_framework.response import Response
|
||||
|
||||
from common.api import JMSModelViewSet
|
||||
from common.permissions import IsValidUser, OnlySuperUser
|
||||
from .. import serializers
|
||||
from ..const import ChatAITypeChoices
|
||||
from ..models import ChatPrompt
|
||||
from ..prompt import DefaultChatPrompt
|
||||
|
||||
|
||||
class ChatAITestingAPI(GenericAPIView):
|
||||
serializer_class = serializers.ChatAISettingSerializer
|
||||
rbac_perms = {
|
||||
'POST': 'settings.change_chatai'
|
||||
}
|
||||
|
||||
def get_config(self, request):
|
||||
serializer = self.serializer_class(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
data = self.serializer_class().data
|
||||
data.update(serializer.validated_data)
|
||||
for k, v in data.items():
|
||||
if v:
|
||||
continue
|
||||
# 页面没有传递值, 从 settings 中获取
|
||||
data[k] = getattr(settings, k, None)
|
||||
return data
|
||||
|
||||
def post(self, request):
|
||||
config = self.get_config(request)
|
||||
chat_ai_enabled = config['CHAT_AI_ENABLED']
|
||||
if not chat_ai_enabled:
|
||||
return Response(
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
data={'msg': _('Chat AI is not enabled')}
|
||||
)
|
||||
|
||||
tp = config['CHAT_AI_TYPE']
|
||||
if tp == ChatAITypeChoices.gpt:
|
||||
url = config['GPT_BASE_URL']
|
||||
api_key = config['GPT_API_KEY']
|
||||
proxy = config['GPT_PROXY']
|
||||
model = config['GPT_MODEL']
|
||||
else:
|
||||
url = config['DEEPSEEK_BASE_URL']
|
||||
api_key = config['DEEPSEEK_API_KEY']
|
||||
proxy = config['DEEPSEEK_PROXY']
|
||||
model = config['DEEPSEEK_MODEL']
|
||||
|
||||
kwargs = {
|
||||
'base_url': url or None,
|
||||
'api_key': api_key,
|
||||
}
|
||||
try:
|
||||
if proxy:
|
||||
kwargs['http_client'] = httpx.Client(
|
||||
proxies=proxy,
|
||||
transport=httpx.HTTPTransport(local_address='0.0.0.0')
|
||||
)
|
||||
client = openai.OpenAI(**kwargs)
|
||||
|
||||
ok = False
|
||||
error = ''
|
||||
|
||||
client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Say this is a test",
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
)
|
||||
ok = True
|
||||
except openai.APIConnectionError as e:
|
||||
error = str(e.__cause__) # an underlying Exception, likely raised within httpx.
|
||||
except openai.APIStatusError as e:
|
||||
error = str(e.message)
|
||||
except Exception as e:
|
||||
ok, error = False, str(e)
|
||||
|
||||
if ok:
|
||||
_status, msg = status.HTTP_200_OK, _('Test success')
|
||||
else:
|
||||
_status, msg = status.HTTP_400_BAD_REQUEST, error
|
||||
|
||||
return Response(status=_status, data={'msg': msg})
|
||||
|
||||
|
||||
class ChatPromptViewSet(JMSModelViewSet):
|
||||
serializer_classes = {
|
||||
'default': serializers.ChatPromptSerializer,
|
||||
|
||||
@@ -154,7 +154,10 @@ class SettingsApi(generics.RetrieveUpdateAPIView):
|
||||
def parse_serializer_data(self, serializer):
|
||||
data = []
|
||||
fields = self.get_fields()
|
||||
encrypted_items = [name for name, field in fields.items() if field.write_only]
|
||||
encrypted_items = [
|
||||
name for name, field in fields.items()
|
||||
if field.write_only or getattr(field, 'encrypted', False)
|
||||
]
|
||||
category = self.request.query_params.get('category', '')
|
||||
for name, value in serializer.validated_data.items():
|
||||
encrypted = name in encrypted_items
|
||||
|
||||
@@ -14,18 +14,5 @@ class ChatAIMethodChoices(TextChoices):
|
||||
|
||||
|
||||
class ChatAITypeChoices(TextChoices):
|
||||
gpt = 'gpt', 'GPT'
|
||||
deep_seek = 'deep-seek', 'DeepSeek'
|
||||
|
||||
|
||||
class GPTModelChoices(TextChoices):
|
||||
gpt_4o_mini = 'gpt-4o-mini', 'gpt-4o-mini'
|
||||
gpt_4o = 'gpt-4o', 'gpt-4o'
|
||||
o3_mini = 'o3-mini', 'o3-mini'
|
||||
o1_mini = 'o1-mini', 'o1-mini'
|
||||
o1 = 'o1', 'o1'
|
||||
|
||||
|
||||
class DeepSeekModelChoices(TextChoices):
|
||||
deepseek_chat = 'deepseek-chat', 'DeepSeek-V3'
|
||||
deepseek_reasoner = 'deepseek-reasoner', 'DeepSeek-R1'
|
||||
openai = 'openai', 'Openai'
|
||||
ollama = 'ollama', 'Ollama'
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.files.base import ContentFile
|
||||
@@ -14,7 +15,6 @@ from rest_framework.utils.encoders import JSONEncoder
|
||||
from common.db.models import JMSBaseModel
|
||||
from common.db.utils import Encryptor
|
||||
from common.utils import get_logger
|
||||
from .const import ChatAITypeChoices
|
||||
from .signals import setting_changed
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -196,20 +196,25 @@ class ChatPrompt(JMSBaseModel):
|
||||
return self.name
|
||||
|
||||
|
||||
def get_chatai_data():
|
||||
data = {
|
||||
'url': settings.GPT_BASE_URL,
|
||||
'api_key': settings.GPT_API_KEY,
|
||||
'proxy': settings.GPT_PROXY,
|
||||
'model': settings.GPT_MODEL,
|
||||
}
|
||||
if settings.CHAT_AI_TYPE != ChatAITypeChoices.gpt:
|
||||
data['url'] = settings.DEEPSEEK_BASE_URL
|
||||
data['api_key'] = settings.DEEPSEEK_API_KEY
|
||||
data['proxy'] = settings.DEEPSEEK_PROXY
|
||||
data['model'] = settings.DEEPSEEK_MODEL
|
||||
def get_chatai_data() -> Dict[str, Any]:
|
||||
raw_providers = settings.CHAT_AI_PROVIDERS
|
||||
providers: List[dict] = [p for p in raw_providers if isinstance(p, dict)]
|
||||
|
||||
return data
|
||||
if not providers:
|
||||
return {}
|
||||
|
||||
selected = next(
|
||||
(p for p in providers if p.get('is_assistant')),
|
||||
providers[0],
|
||||
)
|
||||
|
||||
return {
|
||||
'url': selected.get('base_url'),
|
||||
'api_key': selected.get('api_key'),
|
||||
'proxy': selected.get('proxy'),
|
||||
'model': selected.get('model'),
|
||||
'name': selected.get('name'),
|
||||
}
|
||||
|
||||
|
||||
def init_sqlite_db():
|
||||
|
||||
@@ -10,11 +10,12 @@ from common.utils import date_expired_default
|
||||
__all__ = [
|
||||
'AnnouncementSettingSerializer', 'OpsSettingSerializer', 'VaultSettingSerializer',
|
||||
'HashicorpKVSerializer', 'AzureKVSerializer', 'TicketSettingSerializer',
|
||||
'ChatAISettingSerializer', 'VirtualAppSerializer', 'AmazonSMSerializer',
|
||||
'ChatAIProviderSerializer', 'ChatAISettingSerializer',
|
||||
'VirtualAppSerializer', 'AmazonSMSerializer',
|
||||
]
|
||||
|
||||
from settings.const import (
|
||||
ChatAITypeChoices, GPTModelChoices, DeepSeekModelChoices, ChatAIMethodChoices
|
||||
ChatAITypeChoices, ChatAIMethodChoices
|
||||
)
|
||||
|
||||
|
||||
@@ -120,6 +121,29 @@ class AmazonSMSerializer(serializers.Serializer):
|
||||
)
|
||||
|
||||
|
||||
class ChatAIProviderListSerializer(serializers.ListSerializer):
|
||||
# 标记整个列表需要加密存储,避免明文保存 API Key
|
||||
encrypted = True
|
||||
|
||||
|
||||
class ChatAIProviderSerializer(serializers.Serializer):
|
||||
type = serializers.ChoiceField(
|
||||
default=ChatAITypeChoices.openai, choices=ChatAITypeChoices.choices,
|
||||
label=_("Types"), required=False,
|
||||
)
|
||||
base_url = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Base URL'),
|
||||
help_text=_('The base URL of the Chat service.')
|
||||
)
|
||||
api_key = EncryptedField(
|
||||
allow_blank=True, required=False, label=_('API Key'),
|
||||
)
|
||||
proxy = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Proxy'),
|
||||
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
|
||||
)
|
||||
|
||||
|
||||
class ChatAISettingSerializer(serializers.Serializer):
|
||||
PREFIX_TITLE = _('Chat AI')
|
||||
|
||||
@@ -130,44 +154,14 @@ class ChatAISettingSerializer(serializers.Serializer):
|
||||
default=ChatAIMethodChoices.api, choices=ChatAIMethodChoices.choices,
|
||||
label=_("Method"), required=False,
|
||||
)
|
||||
CHAT_AI_PROVIDERS = ChatAIProviderListSerializer(
|
||||
child=ChatAIProviderSerializer(),
|
||||
allow_empty=True, required=False, default=list, label=_('Providers')
|
||||
)
|
||||
CHAT_AI_EMBED_URL = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Base URL'),
|
||||
help_text=_('The base URL of the Chat service.')
|
||||
)
|
||||
CHAT_AI_TYPE = serializers.ChoiceField(
|
||||
default=ChatAITypeChoices.gpt, choices=ChatAITypeChoices.choices,
|
||||
label=_("Types"), required=False,
|
||||
)
|
||||
GPT_BASE_URL = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Base URL'),
|
||||
help_text=_('The base URL of the Chat service.')
|
||||
)
|
||||
GPT_API_KEY = EncryptedField(
|
||||
allow_blank=True, required=False, label=_('API Key'),
|
||||
)
|
||||
GPT_PROXY = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Proxy'),
|
||||
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
|
||||
)
|
||||
GPT_MODEL = serializers.ChoiceField(
|
||||
default=GPTModelChoices.gpt_4o_mini, choices=GPTModelChoices.choices,
|
||||
label=_("GPT Model"), required=False,
|
||||
)
|
||||
DEEPSEEK_BASE_URL = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Base URL'),
|
||||
help_text=_('The base URL of the Chat service.')
|
||||
)
|
||||
DEEPSEEK_API_KEY = EncryptedField(
|
||||
allow_blank=True, required=False, label=_('API Key'),
|
||||
)
|
||||
DEEPSEEK_PROXY = serializers.CharField(
|
||||
allow_blank=True, required=False, label=_('Proxy'),
|
||||
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
|
||||
)
|
||||
DEEPSEEK_MODEL = serializers.ChoiceField(
|
||||
default=DeepSeekModelChoices.deepseek_chat, choices=DeepSeekModelChoices.choices,
|
||||
label=_("DeepSeek Model"), required=False,
|
||||
)
|
||||
|
||||
|
||||
class TicketSettingSerializer(serializers.Serializer):
|
||||
|
||||
@@ -73,7 +73,6 @@ class PrivateSettingSerializer(PublicSettingSerializer):
|
||||
CHAT_AI_ENABLED = serializers.BooleanField()
|
||||
CHAT_AI_METHOD = serializers.CharField()
|
||||
CHAT_AI_EMBED_URL = serializers.CharField()
|
||||
CHAT_AI_TYPE = serializers.CharField()
|
||||
GPT_MODEL = serializers.CharField()
|
||||
FILE_UPLOAD_SIZE_LIMIT_MB = serializers.IntegerField()
|
||||
FTP_FILE_MAX_STORE = serializers.IntegerField()
|
||||
|
||||
@@ -21,7 +21,6 @@ urlpatterns = [
|
||||
path('sms/<str:backend>/testing/', api.SMSTestingAPI.as_view(), name='sms-testing'),
|
||||
path('sms/backend/', api.SMSBackendAPI.as_view(), name='sms-backend'),
|
||||
path('vault/<str:backend>/testing/', api.VaultTestingAPI.as_view(), name='vault-testing'),
|
||||
path('chatai/testing/', api.ChatAITestingAPI.as_view(), name='chatai-testing'),
|
||||
path('vault/sync/', api.VaultSyncDataAPI.as_view(), name='vault-sync'),
|
||||
path('security/block-ip/', api.BlockIPSecurityAPI.as_view(), name='block-ip'),
|
||||
path('security/unlock-ip/', api.UnlockIPSecurityAPI.as_view(), name='unlock-ip'),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from .applet import *
|
||||
from .chat import *
|
||||
from .component import *
|
||||
from .session import *
|
||||
from .virtualapp import *
|
||||
|
||||
1
apps/terminal/api/chat/__init__.py
Normal file
1
apps/terminal/api/chat/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .chat import *
|
||||
15
apps/terminal/api/chat/chat.py
Normal file
15
apps/terminal/api/chat/chat.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from common.api import JMSBulkModelViewSet
|
||||
from terminal import serializers
|
||||
from terminal.filters import ChatFilter
|
||||
from terminal.models import Chat
|
||||
|
||||
__all__ = ['ChatViewSet']
|
||||
|
||||
|
||||
class ChatViewSet(JMSBulkModelViewSet):
|
||||
queryset = Chat.objects.all()
|
||||
serializer_class = serializers.ChatSerializer
|
||||
filterset_class = ChatFilter
|
||||
search_fields = ['title']
|
||||
ordering_fields = ['date_updated']
|
||||
ordering = ['-date_updated']
|
||||
@@ -2,7 +2,7 @@ from django.db.models import QuerySet
|
||||
from django_filters import rest_framework as filters
|
||||
|
||||
from orgs.utils import filter_org_queryset
|
||||
from terminal.models import Command, CommandStorage, Session
|
||||
from terminal.models import Command, CommandStorage, Session, Chat
|
||||
|
||||
|
||||
class CommandFilter(filters.FilterSet):
|
||||
@@ -79,7 +79,34 @@ class CommandStorageFilter(filters.FilterSet):
|
||||
model = CommandStorage
|
||||
fields = ['real', 'name', 'type', 'is_default']
|
||||
|
||||
def filter_real(self, queryset, name, value):
|
||||
@staticmethod
|
||||
def filter_real(queryset, name, value):
|
||||
if value:
|
||||
queryset = queryset.exclude(name='null')
|
||||
return queryset
|
||||
|
||||
|
||||
class ChatFilter(filters.FilterSet):
|
||||
ids = filters.BooleanFilter(method='filter_ids')
|
||||
folder_ids = filters.BooleanFilter(method='filter_folder_ids')
|
||||
|
||||
|
||||
class Meta:
|
||||
model = Chat
|
||||
fields = [
|
||||
'title', 'user_id', 'pinned', 'folder_id',
|
||||
'archived', 'socket_id', 'share_id'
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def filter_ids(queryset, name, value):
|
||||
ids = value.split(',')
|
||||
queryset = queryset.filter(id__in=ids)
|
||||
return queryset
|
||||
|
||||
|
||||
@staticmethod
|
||||
def filter_folder_ids(queryset, name, value):
|
||||
ids = value.split(',')
|
||||
queryset = queryset.filter(folder_id__in=ids)
|
||||
return queryset
|
||||
|
||||
38
apps/terminal/migrations/0011_chat.py
Normal file
38
apps/terminal/migrations/0011_chat.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# Generated by Django 4.1.13 on 2025-09-30 06:57
|
||||
|
||||
from django.db import migrations, models
|
||||
import uuid
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('terminal', '0010_alter_command_risk_level_alter_session_login_from_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='Chat',
|
||||
fields=[
|
||||
('created_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by')),
|
||||
('updated_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Updated by')),
|
||||
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
|
||||
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
|
||||
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
|
||||
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
|
||||
('title', models.CharField(max_length=256, verbose_name='Title')),
|
||||
('chat', models.JSONField(default=dict, verbose_name='Chat')),
|
||||
('meta', models.JSONField(default=dict, verbose_name='Meta')),
|
||||
('pinned', models.BooleanField(default=False, verbose_name='Pinned')),
|
||||
('archived', models.BooleanField(default=False, verbose_name='Archived')),
|
||||
('share_id', models.CharField(blank=True, default='', max_length=36)),
|
||||
('folder_id', models.CharField(blank=True, default='', max_length=36)),
|
||||
('socket_id', models.CharField(blank=True, default='', max_length=36)),
|
||||
('user_id', models.CharField(blank=True, db_index=True, default='', max_length=36)),
|
||||
('session_info', models.JSONField(default=dict, verbose_name='Session Info')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'Chat',
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -1,4 +1,5 @@
|
||||
from .session import *
|
||||
from .component import *
|
||||
from .applet import *
|
||||
from .chat import *
|
||||
from .component import *
|
||||
from .session import *
|
||||
from .virtualapp import *
|
||||
|
||||
1
apps/terminal/models/chat/__init__.py
Normal file
1
apps/terminal/models/chat/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .chat import *
|
||||
30
apps/terminal/models/chat/chat.py
Normal file
30
apps/terminal/models/chat/chat.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common.db.models import JMSBaseModel
|
||||
from common.utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
__all__ = ['Chat']
|
||||
|
||||
|
||||
class Chat(JMSBaseModel):
|
||||
# id == session_id # 36 chars
|
||||
title = models.CharField(max_length=256, verbose_name=_('Title'))
|
||||
chat = models.JSONField(default=dict, verbose_name=_('Chat'))
|
||||
meta = models.JSONField(default=dict, verbose_name=_('Meta'))
|
||||
pinned = models.BooleanField(default=False, verbose_name=_('Pinned'))
|
||||
archived = models.BooleanField(default=False, verbose_name=_('Archived'))
|
||||
share_id = models.CharField(blank=True, default='', max_length=36)
|
||||
folder_id = models.CharField(blank=True, default='', max_length=36)
|
||||
socket_id = models.CharField(blank=True, default='', max_length=36)
|
||||
user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
|
||||
|
||||
session_info = models.JSONField(default=dict, verbose_name=_('Session Info'))
|
||||
|
||||
class Meta:
|
||||
verbose_name = _('Chat')
|
||||
|
||||
def __str__(self):
|
||||
return self.title
|
||||
@@ -123,11 +123,11 @@ class Terminal(StorageMixin, TerminalStatusMixin, JMSBaseModel):
|
||||
def get_chat_ai_setting():
|
||||
data = get_chatai_data()
|
||||
return {
|
||||
'GPT_BASE_URL': data['url'],
|
||||
'GPT_API_KEY': data['api_key'],
|
||||
'GPT_PROXY': data['proxy'],
|
||||
'GPT_MODEL': data['model'],
|
||||
'CHAT_AI_TYPE': settings.CHAT_AI_TYPE,
|
||||
'GPT_BASE_URL': data.get('url'),
|
||||
'GPT_API_KEY': data.get('api_key'),
|
||||
'GPT_PROXY': data.get('proxy'),
|
||||
'GPT_MODEL': data.get('model'),
|
||||
'CHAT_AI_PROVIDERS': settings.CHAT_AI_PROVIDERS,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
#
|
||||
from .applet import *
|
||||
from .applet_host import *
|
||||
from .chat import *
|
||||
from .command import *
|
||||
from .endpoint import *
|
||||
from .loki import *
|
||||
from .session import *
|
||||
from .sharing import *
|
||||
from .storage import *
|
||||
@@ -11,4 +13,3 @@ from .task import *
|
||||
from .terminal import *
|
||||
from .virtualapp import *
|
||||
from .virtualapp_provider import *
|
||||
from .loki import *
|
||||
|
||||
28
apps/terminal/serializers/chat.py
Normal file
28
apps/terminal/serializers/chat.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.serializers import CommonBulkModelSerializer
|
||||
from terminal.models import Chat
|
||||
|
||||
__all__ = ['ChatSerializer']
|
||||
|
||||
|
||||
class ChatSerializer(CommonBulkModelSerializer):
|
||||
created_at = serializers.SerializerMethodField()
|
||||
updated_at = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = Chat
|
||||
fields_mini = ['id', 'title', 'created_at', 'updated_at']
|
||||
fields = fields_mini + [
|
||||
'chat', 'meta', 'pinned', 'archived',
|
||||
'share_id', 'folder_id',
|
||||
'user_id', 'session_info'
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_created_at(obj):
|
||||
return int(obj.date_created.timestamp())
|
||||
|
||||
@staticmethod
|
||||
def get_updated_at(obj):
|
||||
return int(obj.date_updated.timestamp())
|
||||
@@ -32,6 +32,7 @@ router.register(r'virtual-apps', api.VirtualAppViewSet, 'virtual-app')
|
||||
router.register(r'app-providers', api.AppProviderViewSet, 'app-provider')
|
||||
router.register(r'app-providers/((?P<provider>[^/.]+)/)?apps', api.AppProviderAppViewSet, 'app-provider-app')
|
||||
router.register(r'virtual-app-publications', api.VirtualAppPublicationViewSet, 'virtual-app-publication')
|
||||
router.register(r'chats', api.ChatViewSet, 'chat')
|
||||
|
||||
urlpatterns = [
|
||||
path('my-sessions/', api.MySessionAPIView.as_view(), name='my-session'),
|
||||
|
||||
@@ -199,11 +199,19 @@ class UserChangePasswordApi(UserQuerysetMixin, generics.UpdateAPIView):
|
||||
class UserUnblockPKApi(UserQuerysetMixin, generics.UpdateAPIView):
|
||||
serializer_class = serializers.UserSerializer
|
||||
|
||||
def get_object(self):
|
||||
pk = self.kwargs.get('pk')
|
||||
if is_uuid(pk):
|
||||
return super().get_object()
|
||||
else:
|
||||
return self.get_queryset().filter(username=pk).first()
|
||||
|
||||
def perform_update(self, serializer):
|
||||
user = self.get_object()
|
||||
username = user.username if user else ''
|
||||
LoginBlockUtil.unblock_user(username)
|
||||
MFABlockUtils.unblock_user(username)
|
||||
if not user:
|
||||
return Response({"error": _("User not found")}, status=404)
|
||||
|
||||
user.unblock_login()
|
||||
|
||||
|
||||
class UserResetMFAApi(UserQuerysetMixin, generics.RetrieveAPIView):
|
||||
|
||||
@@ -274,8 +274,8 @@ class User(
|
||||
LoginBlockUtil.unblock_user(self.username)
|
||||
MFABlockUtils.unblock_user(self.username)
|
||||
|
||||
@lazyproperty
|
||||
def login_blocked(self):
|
||||
@property
|
||||
def is_login_blocked(self):
|
||||
from users.utils import LoginBlockUtil, MFABlockUtils
|
||||
|
||||
if LoginBlockUtil.is_user_block(self.username):
|
||||
@@ -284,6 +284,13 @@ class User(
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def block_login(cls, username):
|
||||
from users.utils import LoginBlockUtil, MFABlockUtils
|
||||
|
||||
LoginBlockUtil.block_user(username)
|
||||
MFABlockUtils.block_user(username)
|
||||
|
||||
def delete(self, using=None, keep_parents=False):
|
||||
if self.pk == 1 or self.username == "admin":
|
||||
raise PermissionDenied(_("Can not delete admin user"))
|
||||
|
||||
@@ -123,7 +123,7 @@ class UserSerializer(
|
||||
mfa_force_enabled = serializers.BooleanField(
|
||||
read_only=True, label=_("MFA force enabled")
|
||||
)
|
||||
login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked"))
|
||||
is_login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked"))
|
||||
is_expired = serializers.BooleanField(read_only=True, label=_("Is expired"))
|
||||
is_valid = serializers.BooleanField(read_only=True, label=_("Is valid"))
|
||||
is_otp_secret_key_bound = serializers.BooleanField(
|
||||
@@ -193,6 +193,7 @@ class UserSerializer(
|
||||
"is_valid", "is_expired", "is_active", # 布尔字段
|
||||
"is_otp_secret_key_bound", "can_public_key_auth",
|
||||
"mfa_enabled", "need_update_password", "is_face_code_set",
|
||||
"is_login_blocked",
|
||||
]
|
||||
# 包含不太常用的字段,可以没有
|
||||
fields_verbose = (
|
||||
@@ -211,7 +212,7 @@ class UserSerializer(
|
||||
# 多对多字段
|
||||
fields_m2m = ["groups", "system_roles", "org_roles", "orgs_roles", "labels"]
|
||||
# 在serializer 上定义的字段
|
||||
fields_custom = ["login_blocked", "password_strategy"]
|
||||
fields_custom = ["is_login_blocked", "password_strategy"]
|
||||
fields = fields_verbose + fields_fk + fields_m2m + fields_custom
|
||||
fields_unexport = ["avatar_url", "is_service_account"]
|
||||
|
||||
|
||||
@@ -28,6 +28,6 @@ urlpatterns = [
|
||||
path('users/<uuid:pk>/password/', api.UserChangePasswordApi.as_view(), name='change-user-password'),
|
||||
path('users/<uuid:pk>/password/reset/', api.UserResetPasswordApi.as_view(), name='user-reset-password'),
|
||||
path('users/<uuid:pk>/pubkey/reset/', api.UserResetPKApi.as_view(), name='user-public-key-reset'),
|
||||
path('users/<uuid:pk>/unblock/', api.UserUnblockPKApi.as_view(), name='user-unblock'),
|
||||
path('users/<str:pk>/unblock/', api.UserUnblockPKApi.as_view(), name='user-unblock'),
|
||||
]
|
||||
urlpatterns += router.urls
|
||||
|
||||
@@ -186,6 +186,13 @@ class BlockUtilBase:
|
||||
def is_block(self):
|
||||
return bool(cache.get(self.block_key))
|
||||
|
||||
@classmethod
|
||||
def block_user(cls, username):
|
||||
username = username.lower()
|
||||
block_key = cls.BLOCK_KEY_TMPL.format(username)
|
||||
key_ttl = int(settings.SECURITY_LOGIN_LIMIT_TIME) * 60
|
||||
cache.set(block_key, True, key_ttl)
|
||||
|
||||
@classmethod
|
||||
def get_blocked_usernames(cls):
|
||||
key = cls.BLOCK_KEY_TMPL.format('*')
|
||||
|
||||
Reference in New Issue
Block a user