mirror of
https://github.com/jumpserver/jumpserver.git
synced 2026-03-18 11:02:09 +00:00
refactor: api schema generator
This commit is contained in:
@@ -254,7 +254,7 @@ class UserLoginView(mixins.AuthMixin, UserLoginContextMixin, FormView):
|
||||
else:
|
||||
return get_user_login_form_cls()
|
||||
|
||||
def get_comprehensive_form_class(self):
|
||||
def get_form_class_comprehensive(self):
|
||||
return get_comprehensive_user_login_form_cls()
|
||||
|
||||
def clear_rsa_key(self):
|
||||
|
||||
2547
utils/api_schema_generator/output/webui_schema.json
Normal file
2547
utils/api_schema_generator/output/webui_schema.json
Normal file
File diff suppressed because it is too large
Load Diff
0
utils/api_schema_generator/src/__init__.py
Normal file
0
utils/api_schema_generator/src/__init__.py
Normal file
59
utils/api_schema_generator/src/extractors/base.py
Normal file
59
utils/api_schema_generator/src/extractors/base.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from django.core.handlers.asgi import ASGIRequest
|
||||
|
||||
from routing.discover import Route
|
||||
from routing.resolver import View
|
||||
from schema.endpoint import Endpoint, MethodSchema
|
||||
|
||||
|
||||
class BaseExtractor:
|
||||
|
||||
def __init__(self, view: View):
|
||||
self.view = view
|
||||
self.fake_request = self.get_fake_request()
|
||||
|
||||
def get_fake_request(self) -> ASGIRequest:
|
||||
scope = {
|
||||
'type': 'http',
|
||||
'method': 'GET',
|
||||
'path': '/',
|
||||
'query_string': b'',
|
||||
'headers': [],
|
||||
}
|
||||
async def receive():
|
||||
return {'type': 'http.request', 'body': b''}
|
||||
fake_request = ASGIRequest(scope, receive)
|
||||
setattr(fake_request, 'query_params', {})
|
||||
return fake_request
|
||||
|
||||
def extract(self) -> Endpoint:
|
||||
url = self.view.route.path
|
||||
if url.startswith('/api/v1/users/^users'):
|
||||
pass
|
||||
endpoint = Endpoint(
|
||||
path=self.view.route.path,
|
||||
requires_auth=self.view_requires_auth()
|
||||
)
|
||||
methods = self.get_http_methods()
|
||||
for method in methods:
|
||||
query_fields = self.extract_query_fields(method)
|
||||
body_fields = self.extract_body_fields(method)
|
||||
method_schema = MethodSchema(
|
||||
method=method,
|
||||
query_fields=query_fields,
|
||||
body_fields=body_fields,
|
||||
)
|
||||
endpoint.methods[method] = method_schema
|
||||
return endpoint
|
||||
|
||||
def get_http_methods(self) -> list:
|
||||
return ['GET', 'POST']
|
||||
|
||||
def view_requires_auth(self) -> bool:
|
||||
return False
|
||||
|
||||
def extract_query_fields(self, method: str) -> list:
|
||||
return []
|
||||
|
||||
def extract_body_fields(self, method: str) -> list:
|
||||
return []
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
from django import forms
|
||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||
|
||||
from .base import BaseExtractor
|
||||
from .django_view import DjangoViewExtractor
|
||||
from routing.discover import Route
|
||||
from routing.resolver import View
|
||||
from schema.endpoint import QueryField, BodyField
|
||||
|
||||
|
||||
class DjangoFormViewExtractor(DjangoViewExtractor):
|
||||
|
||||
def extract_body_fields(self, method: str):
|
||||
form_class = self.get_form_class()
|
||||
if not form_class:
|
||||
return []
|
||||
form: forms.Form = form_class()
|
||||
body_fields = []
|
||||
for field_name, field in form.fields.items():
|
||||
field: forms.Field
|
||||
body_field = BodyField(
|
||||
name=field_name,
|
||||
field_type=self.get_field_type(field),
|
||||
required=field.required,
|
||||
description=str(field.help_text) or '',
|
||||
)
|
||||
body_fields.append(body_field)
|
||||
return body_fields
|
||||
|
||||
def get_field_type(self, form_field):
|
||||
field_type_mapping = {
|
||||
forms.CharField: 'string',
|
||||
forms.EmailField: 'string',
|
||||
forms.URLField: 'string',
|
||||
forms.SlugField: 'string',
|
||||
forms.UUIDField: 'string',
|
||||
forms.RegexField: 'string',
|
||||
forms.FileField: 'string',
|
||||
forms.ImageField: 'string',
|
||||
forms.FilePathField: 'string',
|
||||
forms.GenericIPAddressField: 'string',
|
||||
forms.IntegerField: 'integer',
|
||||
forms.FloatField: 'number',
|
||||
forms.DecimalField: 'number',
|
||||
forms.BooleanField: 'boolean',
|
||||
forms.NullBooleanField: 'boolean',
|
||||
forms.DateField: 'string',
|
||||
forms.TimeField: 'string',
|
||||
forms.DateTimeField: 'string',
|
||||
forms.DurationField: 'string',
|
||||
forms.MultipleChoiceField: 'array',
|
||||
forms.TypedMultipleChoiceField: 'array',
|
||||
forms.ModelMultipleChoiceField: 'array',
|
||||
forms.ChoiceField: 'string',
|
||||
forms.TypedChoiceField: 'string',
|
||||
forms.ModelChoiceField: 'string',
|
||||
forms.JSONField: 'object',
|
||||
}
|
||||
for field_type, json_type in field_type_mapping.items():
|
||||
if issubclass(type(form_field), field_type):
|
||||
return json_type
|
||||
return 'string'
|
||||
# raise ValueError(f"Unsupported form field type: {type(form_field)}")
|
||||
|
||||
def get_form_class(self):
|
||||
view = self.view
|
||||
form_class = getattr(view.view_class, 'form_class', None)
|
||||
if form_class:
|
||||
return form_class
|
||||
|
||||
view_instance = view.view_class(request=self.fake_request)
|
||||
|
||||
if hasattr(view_instance, 'get_form_class_comprehensive'):
|
||||
form_class = view_instance.get_form_class_comprehensive()
|
||||
return form_class
|
||||
|
||||
if hasattr(view_instance, 'get_form_class'):
|
||||
form_class = view_instance.get_form_class()
|
||||
return form_class
|
||||
@@ -0,0 +1,5 @@
|
||||
from .base import BaseExtractor
|
||||
|
||||
|
||||
class DjangoFunctionViewExtractor(BaseExtractor):
|
||||
pass
|
||||
57
utils/api_schema_generator/src/extractors/django_view.py
Normal file
57
utils/api_schema_generator/src/extractors/django_view.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.permissions import OperandHolder, AND, OR, NOT
|
||||
from rbac.permissions import RBACPermission
|
||||
|
||||
from .base import BaseExtractor
|
||||
|
||||
__all__ = ['DjangoViewExtractor']
|
||||
|
||||
|
||||
class DjangoViewExtractor(BaseExtractor):
|
||||
|
||||
def view_requires_auth(self):
|
||||
if issubclass(self.view.view_class, LoginRequiredMixin):
|
||||
return True
|
||||
permission_classes = getattr(self.view.view_class, 'permission_classes', [])
|
||||
if not permission_classes:
|
||||
return False
|
||||
return self.check_view_permission_classes_requires_auth(permission_classes)
|
||||
|
||||
def check_view_permission_classes_requires_auth(self, permission_classes, operator=AND):
|
||||
if operator == AND:
|
||||
for pc in permission_classes:
|
||||
if self.check_view_permission_class_requires_auth(pc):
|
||||
return True
|
||||
return False
|
||||
elif operator == OR:
|
||||
for pc in permission_classes:
|
||||
if not self.check_view_permission_class_requires_auth(pc):
|
||||
return False
|
||||
return True
|
||||
elif operator == NOT:
|
||||
raise ValueError('NOT operator is not supported in permission_classes')
|
||||
else:
|
||||
return False
|
||||
|
||||
def check_view_permission_class_requires_auth(self, permission_class):
|
||||
if isinstance(permission_class, OperandHolder):
|
||||
operator = permission_class.operator_class
|
||||
op1_class = permission_class.op1_class
|
||||
op2_class = permission_class.op2_class
|
||||
permission_classes = [op1_class, op2_class]
|
||||
return self.check_view_permission_classes_requires_auth(permission_classes, operator)
|
||||
else:
|
||||
if issubclass(permission_class, (IsAuthenticated, RBACPermission)):
|
||||
return True
|
||||
if issubclass(permission_class, (AllowAny, )):
|
||||
return False
|
||||
|
||||
permission_class_name: str = getattr(permission_class, '__name__', None)
|
||||
if not permission_class_name:
|
||||
return False
|
||||
if 'Authenticated' in permission_class_name:
|
||||
return True
|
||||
if permission_class_name.startswith('UserConfirmation'):
|
||||
return True
|
||||
return False
|
||||
81
utils/api_schema_generator/src/extractors/drf_api_view.py
Normal file
81
utils/api_schema_generator/src/extractors/drf_api_view.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from rest_framework import serializers
|
||||
from .base import BaseExtractor
|
||||
from .django_view import DjangoViewExtractor
|
||||
from schema.endpoint import BodyField
|
||||
|
||||
class DrfAPIViewExtractor(DjangoViewExtractor):
|
||||
|
||||
def extract_body_fields(self, method: str) -> list:
|
||||
serializer_class = self.get_serializer_class()
|
||||
if not serializer_class:
|
||||
return []
|
||||
serializer = serializer_class()
|
||||
fields = self.get_serializer_fields(serializer)
|
||||
body_fields = self.wrap_as_body_fields(fields)
|
||||
return body_fields
|
||||
|
||||
def wrap_as_body_fields(self, serializer_fields):
|
||||
if not hasattr(serializer_fields, 'items'):
|
||||
return []
|
||||
body_fields = []
|
||||
for field_name, field in serializer_fields.items():
|
||||
field: serializers.Field
|
||||
body_field = BodyField(
|
||||
name=field_name,
|
||||
field_type=self.get_field_type(field),
|
||||
required=field.required,
|
||||
description=str(field.help_text) or ''
|
||||
)
|
||||
if hasattr(field, 'child'):
|
||||
_body_fields = self.wrap_as_body_fields(field.child)
|
||||
body_field.extend_child(_body_fields)
|
||||
|
||||
body_fields.append(body_field)
|
||||
return body_fields
|
||||
|
||||
def get_field_type(self, serializer_field):
|
||||
if hasattr(serializer_field, 'child'):
|
||||
return 'object'
|
||||
|
||||
field_type_mapping = {
|
||||
serializers.CharField: 'string',
|
||||
serializers.EmailField: 'string',
|
||||
serializers.URLField: 'string',
|
||||
serializers.UUIDField: 'string',
|
||||
serializers.SlugField: 'string',
|
||||
serializers.ChoiceField: 'string',
|
||||
serializers.IntegerField: 'integer',
|
||||
serializers.FloatField: 'number',
|
||||
serializers.DecimalField: 'number',
|
||||
serializers.BooleanField: 'boolean',
|
||||
serializers.DateTimeField: 'string',
|
||||
serializers.DateField: 'string',
|
||||
serializers.TimeField: 'string',
|
||||
serializers.ListField: 'array',
|
||||
serializers.DictField: 'object',
|
||||
serializers.JSONField: 'object',
|
||||
}
|
||||
for field_type, json_type in field_type_mapping.items():
|
||||
if issubclass(type(serializer_field), field_type):
|
||||
return json_type
|
||||
return 'string'
|
||||
|
||||
def get_serializer_fields(self, serializer):
|
||||
try:
|
||||
fields = serializer.fields
|
||||
except Exception as e:
|
||||
fields = getattr(self.view.view_class, '_declared_fields', {})
|
||||
return fields
|
||||
|
||||
def get_serializer_class(self):
|
||||
serializer_class = getattr(self.view.view_class, 'serializer_class', None)
|
||||
if serializer_class:
|
||||
return serializer_class
|
||||
|
||||
view_instance = self.view.view_class(request=self.fake_request)
|
||||
if hasattr(view_instance, 'get_serializer_class'):
|
||||
try:
|
||||
serializer_class = view_instance.get_serializer_class()
|
||||
except Exception as e:
|
||||
serializer_class = None
|
||||
return serializer_class
|
||||
30
utils/api_schema_generator/src/extractors/selector.py
Normal file
30
utils/api_schema_generator/src/extractors/selector.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from django.views.generic.edit import FormView as DjangoFormView
|
||||
from django.views import View as DjangoView
|
||||
from rest_framework.views import APIView as DrfAPIView
|
||||
|
||||
from routing.resolver import View
|
||||
|
||||
from .django_func_view import DjangoFunctionViewExtractor
|
||||
from .django_form_view import DjangoFormViewExtractor
|
||||
from .django_view import DjangoViewExtractor
|
||||
from .drf_api_view import DrfAPIViewExtractor
|
||||
|
||||
|
||||
def select_extractor(view: View):
|
||||
|
||||
if view.is_func_based:
|
||||
return DjangoFunctionViewExtractor(view)
|
||||
|
||||
view_class = view.view_class
|
||||
|
||||
if issubclass(view_class, DrfAPIView):
|
||||
return DrfAPIViewExtractor(view)
|
||||
|
||||
if issubclass(view_class, DjangoFormView):
|
||||
return DjangoFormViewExtractor(view)
|
||||
|
||||
if issubclass(view_class, DjangoView):
|
||||
return DjangoViewExtractor(view)
|
||||
|
||||
raise NotImplementedError(f'Unsupported view class: {view_class}')
|
||||
|
||||
48
utils/api_schema_generator/src/main.py
Normal file
48
utils/api_schema_generator/src/main.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
import sys
|
||||
import django
|
||||
|
||||
# 获取项目根目录(jumpserver 目录)
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
APP_DIR = os.path.join(BASE_DIR, 'apps')
|
||||
|
||||
|
||||
# 不改变工作目录,直接加入 sys.path
|
||||
sys.path.insert(0, APP_DIR)
|
||||
sys.path.insert(0, BASE_DIR)
|
||||
|
||||
# 设置 Django 环境
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
|
||||
django.setup()
|
||||
|
||||
|
||||
from routing.discover import discover_routes
|
||||
from routing.resolver import resolve_view
|
||||
from extractors.selector import select_extractor
|
||||
from schema.renderer import render_schema
|
||||
|
||||
|
||||
def generate_api_schema():
|
||||
# 发现所有路由
|
||||
routes = discover_routes()
|
||||
endpoints = []
|
||||
|
||||
for route in routes:
|
||||
# 解析视图
|
||||
view = resolve_view(route)
|
||||
# 选择视图提取器
|
||||
extractor = select_extractor(view)
|
||||
# 提取端点信息
|
||||
endpoint = extractor.extract()
|
||||
if not endpoint:
|
||||
continue
|
||||
endpoints.append(endpoint)
|
||||
|
||||
# 渲染最终的 API Schema
|
||||
schema = render_schema(endpoints)
|
||||
return schema
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
api_schema = generate_api_schema()
|
||||
print(api_schema)
|
||||
0
utils/api_schema_generator/src/routing/__init__.py
Normal file
0
utils/api_schema_generator/src/routing/__init__.py
Normal file
34
utils/api_schema_generator/src/routing/discover.py
Normal file
34
utils/api_schema_generator/src/routing/discover.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from django.urls import get_resolver
|
||||
from django.urls.resolvers import URLPattern, URLResolver
|
||||
|
||||
|
||||
__all__ = ['discover_routes']
|
||||
|
||||
|
||||
class Route:
|
||||
|
||||
def __init__(self, url_pattern: URLPattern, path_prefix):
|
||||
self.url_pattern = url_pattern
|
||||
self.path = f'{path_prefix}{url_pattern.pattern}'
|
||||
self.callback = url_pattern.callback
|
||||
|
||||
|
||||
def extract_url_patterns(patterns, path_prefix='/'):
|
||||
routes = []
|
||||
for p in patterns:
|
||||
if isinstance(p, URLResolver):
|
||||
_path_prefix = f'{path_prefix}{p.pattern}'
|
||||
_routes = extract_url_patterns(p.url_patterns, path_prefix=_path_prefix)
|
||||
routes.extend(_routes)
|
||||
elif isinstance(p, URLPattern):
|
||||
route = Route(url_pattern=p, path_prefix=path_prefix)
|
||||
routes.append(route)
|
||||
else:
|
||||
print(f'Skip: unknown pattern type: {type(p)}')
|
||||
return routes
|
||||
|
||||
|
||||
def discover_routes():
|
||||
resolver = get_resolver()
|
||||
routes = extract_url_patterns(resolver.url_patterns)
|
||||
return routes
|
||||
25
utils/api_schema_generator/src/routing/resolver.py
Normal file
25
utils/api_schema_generator/src/routing/resolver.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from .discover import Route
|
||||
|
||||
__all__ = ['resolve_view', 'View']
|
||||
|
||||
|
||||
class View:
|
||||
|
||||
def __init__(self, route: Route, view_func, view_class):
|
||||
self.route = route
|
||||
self.view_func = view_func
|
||||
self.view_class = view_class
|
||||
|
||||
@property
|
||||
def is_func_based(self):
|
||||
return self.view_class is None
|
||||
|
||||
|
||||
def resolve_view(route: Route) -> View:
|
||||
view_func = route.callback
|
||||
view_class = getattr(view_func, 'view_class', None)
|
||||
if not view_class:
|
||||
view_class = getattr(view_func, 'cls', None)
|
||||
view = View(route=route, view_func=view_func, view_class=view_class)
|
||||
return view
|
||||
|
||||
0
utils/api_schema_generator/src/schema/__init__.py
Normal file
0
utils/api_schema_generator/src/schema/__init__.py
Normal file
41
utils/api_schema_generator/src/schema/endpoint.py
Normal file
41
utils/api_schema_generator/src/schema/endpoint.py
Normal file
@@ -0,0 +1,41 @@
|
||||
|
||||
__all__ = ['Endpoint', 'MethodSchema', 'QueryField', 'BodyField']
|
||||
|
||||
|
||||
class Field:
|
||||
|
||||
def __init__(self, name, field_type, required=False, description=''):
|
||||
self.name = name
|
||||
self.field_type = field_type
|
||||
self.required = required
|
||||
self.description = description
|
||||
|
||||
|
||||
class QueryField(Field):
|
||||
pass
|
||||
|
||||
|
||||
class BodyField(Field):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.child = []
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def extend_child(self, child):
|
||||
self.child.extend(child)
|
||||
|
||||
|
||||
class MethodSchema:
|
||||
|
||||
def __init__(self, method, query_fields, body_fields):
|
||||
self.method = method
|
||||
self.query_fields = query_fields
|
||||
self.body_fields = body_fields
|
||||
|
||||
|
||||
class Endpoint:
|
||||
|
||||
def __init__(self, path, requires_auth=True):
|
||||
self.path = path
|
||||
self.methods = {}
|
||||
self.requires_auth = requires_auth
|
||||
71
utils/api_schema_generator/src/schema/renderer.py
Normal file
71
utils/api_schema_generator/src/schema/renderer.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import json
|
||||
from .endpoint import Endpoint, MethodSchema, QueryField, BodyField
|
||||
|
||||
dirname = os.path.dirname
|
||||
BASE_DIR = dirname(dirname(dirname(os.path.abspath(__file__))))
|
||||
OUTPUT_FILE_DIR = os.path.join(BASE_DIR, 'output')
|
||||
os.makedirs(OUTPUT_FILE_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def write_to_file(data, file_path):
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def write_webui_schema(endpoints):
|
||||
filename = 'webui_schema.json'
|
||||
file_path = os.path.join(OUTPUT_FILE_DIR, filename)
|
||||
data = {
|
||||
'GET': {},
|
||||
'POST': {}
|
||||
}
|
||||
for e in endpoints:
|
||||
e: Endpoint
|
||||
if e.requires_auth:
|
||||
continue
|
||||
url = e.path
|
||||
if url.startswith('/api/v1/users/^users'):
|
||||
pass
|
||||
for method, method_schema in e.methods.items():
|
||||
item = {
|
||||
'allowIf': 'prelogin',
|
||||
}
|
||||
method_schema: MethodSchema
|
||||
query_properties = {}
|
||||
for field in method_schema.query_fields:
|
||||
field: QueryField
|
||||
query_properties[field.name] = {
|
||||
'type': field.field_type,
|
||||
'description': field.description
|
||||
}
|
||||
query = {
|
||||
"type": "object",
|
||||
"properties": query_properties,
|
||||
"required": [],
|
||||
"additionalProperties": False
|
||||
}
|
||||
item['query'] = query
|
||||
if method in ['POST']:
|
||||
body_properties = {}
|
||||
for field in method_schema.body_fields:
|
||||
field: BodyField
|
||||
body_properties[field.name] = {
|
||||
'type': field.field_type,
|
||||
'description': field.description
|
||||
}
|
||||
body = {
|
||||
'type': 'object',
|
||||
'properties': body_properties,
|
||||
'required': [],
|
||||
"additionalProperties": False
|
||||
}
|
||||
item['body'] = body
|
||||
|
||||
data[method][url] = item
|
||||
|
||||
write_to_file(data, file_path)
|
||||
|
||||
|
||||
def render_schema(endpoints):
|
||||
write_webui_schema(endpoints)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .base import *
|
||||
from .form import *
|
||||
from .serializer import *
|
||||
12
utils/generate_view_schema/src/body_fields_generator/base.py
Normal file
12
utils/generate_view_schema/src/body_fields_generator/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
|
||||
|
||||
__all__ = ['FieldsGenerator']
|
||||
|
||||
|
||||
class FieldsGenerator:
|
||||
|
||||
def write_fields_schema(self):
|
||||
return {}
|
||||
|
||||
def required_fields(self):
|
||||
return []
|
||||
75
utils/generate_view_schema/src/body_fields_generator/form.py
Normal file
75
utils/generate_view_schema/src/body_fields_generator/form.py
Normal file
@@ -0,0 +1,75 @@
|
||||
|
||||
from common.utils import lazyproperty
|
||||
from .base import FieldsGenerator
|
||||
|
||||
__all__ = ['FormFieldsGenerator']
|
||||
|
||||
|
||||
class FormFieldsGenerator(FieldsGenerator):
|
||||
def __init__(self, raw_class, view):
|
||||
self.raw_class = raw_class
|
||||
self.view = view
|
||||
|
||||
@lazyproperty
|
||||
def fields(self):
|
||||
return self.raw_class().fields
|
||||
|
||||
def write_fields_schema(self):
|
||||
schema = self.get_fields_schema(self.fields)
|
||||
return schema
|
||||
|
||||
def required_fields(self):
|
||||
fields = [name for name, field in self.fields.items() if field.required]
|
||||
return fields
|
||||
|
||||
def get_fields_schema(self, fields):
|
||||
schemas = {}
|
||||
for name, field in fields.items():
|
||||
schema = {
|
||||
'type': self.get_field_type(field),
|
||||
}
|
||||
description = getattr(field, 'help_text', '')
|
||||
if description:
|
||||
schema['description'] = str(description)
|
||||
schemas[name] = schema
|
||||
return schemas
|
||||
|
||||
def get_field_type(self, field):
|
||||
"""将 Django Form Field 类型映射到 JSON Schema 类型"""
|
||||
from django import forms
|
||||
|
||||
type_mapping = {
|
||||
forms.CharField: 'string',
|
||||
forms.EmailField: 'string',
|
||||
forms.URLField: 'string',
|
||||
forms.SlugField: 'string',
|
||||
forms.UUIDField: 'string',
|
||||
forms.RegexField: 'string',
|
||||
forms.FileField: 'string',
|
||||
forms.ImageField: 'string',
|
||||
forms.FilePathField: 'string',
|
||||
forms.GenericIPAddressField: 'string',
|
||||
forms.IntegerField: 'integer',
|
||||
forms.FloatField: 'number',
|
||||
forms.DecimalField: 'number',
|
||||
forms.BooleanField: 'boolean',
|
||||
forms.NullBooleanField: 'boolean',
|
||||
forms.DateField: 'string',
|
||||
forms.TimeField: 'string',
|
||||
forms.DateTimeField: 'string',
|
||||
forms.DurationField: 'string',
|
||||
forms.MultipleChoiceField: 'array',
|
||||
forms.TypedMultipleChoiceField: 'array',
|
||||
forms.ModelMultipleChoiceField: 'array',
|
||||
forms.ChoiceField: 'string',
|
||||
forms.TypedChoiceField: 'string',
|
||||
forms.ModelChoiceField: 'string',
|
||||
forms.JSONField: 'object',
|
||||
}
|
||||
|
||||
for field_type, json_type in type_mapping.items():
|
||||
if isinstance(field, field_type):
|
||||
return json_type
|
||||
|
||||
return 'string'
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
from .base import FieldsGenerator
|
||||
from common.utils import lazyproperty
|
||||
|
||||
|
||||
__all__ = ['SerializerFieldsGenerator']
|
||||
|
||||
|
||||
class SerializerFieldsGenerator(FieldsGenerator):
|
||||
|
||||
def __init__(self, raw_class, view):
|
||||
self.raw_class = raw_class
|
||||
self.view = view
|
||||
|
||||
@lazyproperty
|
||||
def fields(self):
|
||||
fields = {}
|
||||
try:
|
||||
fields = self.raw_class().fields
|
||||
except Exception as e:
|
||||
if hasattr(self.raw_class, '_declared_fields'):
|
||||
fields = self.raw_class._declared_fields
|
||||
return fields
|
||||
|
||||
@lazyproperty
|
||||
def write_fields(self):
|
||||
fields = {}
|
||||
for name, field in self.fields.items():
|
||||
if field.read_only:
|
||||
continue
|
||||
fields[name] = field
|
||||
return fields
|
||||
|
||||
def get_fields_schema(self, fields):
|
||||
schemas = {}
|
||||
if not hasattr(fields, 'items'):
|
||||
return {}
|
||||
for name, field in fields.items():
|
||||
schema = self.get_field_schema(field)
|
||||
if hasattr(field, 'child'):
|
||||
_fields = field.child
|
||||
_fields_schema = self.get_fields_schema(_fields)
|
||||
schema['properties'] = _fields_schema
|
||||
schemas[name] = schema
|
||||
return schemas
|
||||
|
||||
def get_field_schema(self, field):
|
||||
schema = {
|
||||
'type': self.get_field_type(field),
|
||||
}
|
||||
if description := getattr(field, 'help_text', ''):
|
||||
schema['description'] = str(description)
|
||||
extra_schema = self.get_field_extra_schema(field)
|
||||
schema.update(extra_schema)
|
||||
return schema
|
||||
|
||||
def field_is_nested(self, field):
|
||||
from rest_framework import serializers
|
||||
nested_field_types = (
|
||||
serializers.Serializer,
|
||||
serializers.ModelSerializer,
|
||||
)
|
||||
return isinstance(field, nested_field_types)
|
||||
|
||||
def get_field_extra_schema(self, field):
|
||||
"""获取字段的正则表达式模式"""
|
||||
patterns = {}
|
||||
# 检查 validators 中是否有正则验证器
|
||||
if hasattr(field, 'validators'):
|
||||
for validator in field.validators:
|
||||
if hasattr(validator, 'regex'):
|
||||
patterns['pattern'] = str(validator.regex.pattern)
|
||||
break
|
||||
# 针对特定字段类型添加模式
|
||||
from rest_framework import serializers
|
||||
if isinstance(field, serializers.EmailField):
|
||||
patterns['format'] = 'email'
|
||||
elif isinstance(field, serializers.URLField):
|
||||
patterns['format'] = 'uri'
|
||||
elif isinstance(field, (serializers.DateTimeField, serializers.DateField, serializers.TimeField)):
|
||||
patterns['format'] = 'date-time'
|
||||
|
||||
# 添加长度限制
|
||||
if hasattr(field, 'max_length') and field.max_length:
|
||||
patterns['maxLength'] = field.max_length
|
||||
if hasattr(field, 'min_length') and field.min_length:
|
||||
patterns['minLength'] = field.min_length
|
||||
|
||||
# 添加数值范围
|
||||
if hasattr(field, 'max_value') and field.max_value is not None:
|
||||
patterns['maximum'] = field.max_value
|
||||
if hasattr(field, 'min_value') and field.min_value is not None:
|
||||
patterns['minimum'] = field.min_value
|
||||
|
||||
if choices := self.get_field_choices(field):
|
||||
patterns['enum'] = choices
|
||||
return patterns
|
||||
|
||||
def get_field_choices(self, field):
|
||||
from rest_framework import serializers
|
||||
choices = []
|
||||
field_need_query_db = isinstance(field, (
|
||||
serializers.PrimaryKeyRelatedField, # 会查询数据库
|
||||
serializers.StringRelatedField, # 会查询数据库
|
||||
serializers.SlugRelatedField, # 会查询数据库
|
||||
serializers.HyperlinkedRelatedField, # 会查询数据库
|
||||
serializers.HyperlinkedIdentityField,# 会查询数据库
|
||||
serializers.RelatedField, # 基类,会查询数据库
|
||||
serializers.ManyRelatedField, # 会查询数据库
|
||||
serializers.ListSerializer # 可能会查询数据库
|
||||
))
|
||||
if field_need_query_db:
|
||||
return choices # 不返回需要查询数据库的字段选项
|
||||
|
||||
choices = getattr(field, 'choices', [])
|
||||
if not choices:
|
||||
return choices
|
||||
|
||||
if isinstance(choices, dict):
|
||||
# choices 可能是字典、列表或元组
|
||||
choices = list(choices.keys())
|
||||
elif isinstance(choices, (list, tuple)):
|
||||
# choices 可能是 [(value, label), ...] 或 [value, ...]
|
||||
for choice in choices:
|
||||
if isinstance(choice, (list, tuple)) and len(choice) == 2:
|
||||
choices.append(choice[0])
|
||||
else:
|
||||
choices.append(choice)
|
||||
return choices
|
||||
|
||||
def get_field_type(self, field):
|
||||
"""将 Python 字段类型映射到 JSON Schema 类型"""
|
||||
from rest_framework import serializers
|
||||
|
||||
if self.field_is_nested(field):
|
||||
return 'object'
|
||||
|
||||
type_mapping = {
|
||||
serializers.CharField: 'string',
|
||||
serializers.EmailField: 'string',
|
||||
serializers.URLField: 'string',
|
||||
serializers.UUIDField: 'string',
|
||||
serializers.SlugField: 'string',
|
||||
serializers.ChoiceField: 'string',
|
||||
serializers.IntegerField: 'integer',
|
||||
serializers.FloatField: 'number',
|
||||
serializers.DecimalField: 'number',
|
||||
serializers.BooleanField: 'boolean',
|
||||
serializers.DateTimeField: 'string',
|
||||
serializers.DateField: 'string',
|
||||
serializers.TimeField: 'string',
|
||||
serializers.ListField: 'array',
|
||||
serializers.DictField: 'object',
|
||||
serializers.JSONField: 'object',
|
||||
}
|
||||
|
||||
for field_type, json_type in type_mapping.items():
|
||||
if isinstance(field, field_type):
|
||||
return json_type
|
||||
|
||||
return 'string' # 默认类型
|
||||
|
||||
def write_fields_schema(self):
|
||||
fields = self.write_fields
|
||||
if not fields:
|
||||
return {}
|
||||
schema = self.get_fields_schema(fields)
|
||||
return schema
|
||||
|
||||
def required_fields(self):
|
||||
required = []
|
||||
for name, field in self.write_fields.items():
|
||||
if field.required:
|
||||
required.append(name)
|
||||
return required
|
||||
25
utils/generate_view_schema/src/const.py
Normal file
25
utils/generate_view_schema/src/const.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import os
|
||||
from django.core.handlers.asgi import ASGIRequest
|
||||
|
||||
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
OUTPUT_FILE_DIR = os.path.join(os.path.join(CURRENT_DIR, 'output'))
|
||||
|
||||
|
||||
# Fake Request object 尝试使用模拟请求
|
||||
scope = {
|
||||
'type': 'http',
|
||||
'method': 'GET',
|
||||
'path': '/',
|
||||
'query_string': b'',
|
||||
'headers': [],
|
||||
}
|
||||
|
||||
async def receive():
|
||||
return {'type': 'http.request', 'body': b''}
|
||||
|
||||
fake_request = ASGIRequest(scope, receive)
|
||||
|
||||
|
||||
def log(message):
|
||||
print(message)
|
||||
24
utils/generate_view_schema/src/main.py
Normal file
24
utils/generate_view_schema/src/main.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import os
|
||||
import sys
|
||||
import django
|
||||
|
||||
# 获取项目根目录(jumpserver 目录)
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
APP_DIR = os.path.join(BASE_DIR, 'apps')
|
||||
|
||||
# 不改变工作目录,直接加入 sys.path
|
||||
sys.path.insert(0, APP_DIR)
|
||||
sys.path.insert(0, BASE_DIR)
|
||||
|
||||
# 设置 Django 环境
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
|
||||
django.setup()
|
||||
|
||||
|
||||
|
||||
from .view_schema_generator import ViewSchemaGenerator
|
||||
from .const import OUTPUT_FILE_DIR
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ViewSchemaGenerator(output_file_dir=OUTPUT_FILE_DIR).generate()
|
||||
19
utils/generate_view_schema/src/url_pattern.py
Normal file
19
utils/generate_view_schema/src/url_pattern.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from .view import CustomView
|
||||
|
||||
__all__ = ['CustomURLPattern']
|
||||
|
||||
|
||||
class CustomURLPattern:
|
||||
def __init__(self, raw, prefix='/'):
|
||||
self.raw = raw
|
||||
self.prefix = prefix
|
||||
self.full_path = f'{self.prefix}{self.raw.pattern}'
|
||||
self.view = CustomView(view_func=self.raw.callback)
|
||||
|
||||
def __str__(self):
|
||||
s = f'{self.full_path} -> {self.view.view_path}'
|
||||
return s
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
156
utils/generate_view_schema/src/view.py
Normal file
156
utils/generate_view_schema/src/view.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.permissions import OperandHolder, AND, OR, NOT
|
||||
from django.views.generic.edit import FormView
|
||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rbac.permissions import RBACPermission
|
||||
from common.utils import lazyproperty
|
||||
from .body_fields_generator import (
|
||||
FieldsGenerator, FormFieldsGenerator, SerializerFieldsGenerator
|
||||
)
|
||||
from .const import fake_request
|
||||
|
||||
|
||||
__all__ = ['CustomView']
|
||||
|
||||
|
||||
class CustomView:
|
||||
|
||||
def __init__(self, view_func):
|
||||
self.view_func = view_func
|
||||
|
||||
@property
|
||||
def view_class(self):
|
||||
cls = getattr(self.view_func, 'view_class', None)
|
||||
if not cls:
|
||||
cls = getattr(self.view_func, 'cls', None)
|
||||
return cls
|
||||
|
||||
@property
|
||||
def view_path(self):
|
||||
if self.view_class:
|
||||
v = self.view_class
|
||||
else:
|
||||
v = self.view_func
|
||||
return f'{v.__module__}.{v.__name__}'
|
||||
|
||||
@property
|
||||
def view_type(self):
|
||||
if self.view_class:
|
||||
return 'class'
|
||||
else:
|
||||
return 'function'
|
||||
|
||||
@lazyproperty
|
||||
def fields_generator(self):
|
||||
generator = None
|
||||
if self.view_class:
|
||||
if issubclass(self.view_class, FormView):
|
||||
generator = self.get_form_fields_generator()
|
||||
elif issubclass(self.view_class, APIView):
|
||||
generator= self.get_serializer_fields_generator()
|
||||
else:
|
||||
# 其他类视图暂不处理
|
||||
pass
|
||||
else:
|
||||
# 函数视图暂不处理
|
||||
pass
|
||||
if not generator:
|
||||
generator = FieldsGenerator()
|
||||
return generator
|
||||
|
||||
@property
|
||||
def write_fields_schema(self):
|
||||
return self.fields_generator.write_fields_schema()
|
||||
|
||||
@property
|
||||
def required_fields(self):
|
||||
return self.fields_generator.required_fields()
|
||||
|
||||
def get_form_fields_generator(self):
|
||||
if hasattr(self.view_class, 'get_form_class_comprehensive'):
|
||||
view_instance = self.view_class(request=fake_request)
|
||||
form_class = view_instance.get_form_class_comprehensive()
|
||||
else:
|
||||
form_class = getattr(self.view_class, 'form_class', None)
|
||||
if not form_class:
|
||||
if hasattr(self.view_class, 'get_form_class'):
|
||||
# TODO: 实例化 view 类需要传入 request 参数
|
||||
view_instance = self.view_class(request=fake_request)
|
||||
form_class = view_instance.get_form_class()
|
||||
if form_class:
|
||||
return FormFieldsGenerator(raw_class=form_class, view=self)
|
||||
|
||||
def get_serializer_fields_generator(self):
|
||||
serializer_class = getattr(self.view_class, 'serializer_class', None)
|
||||
if not serializer_class:
|
||||
if hasattr(self.view_class, 'get_serializer_class'):
|
||||
# TODO: 实例化 view 类需要传入 request 参数
|
||||
view_instance = self.view_class(request=fake_request)
|
||||
serializer_class = view_instance.get_serializer_class()
|
||||
if serializer_class:
|
||||
return SerializerFieldsGenerator(raw_class=serializer_class, view=self)
|
||||
|
||||
@property
|
||||
def query_fields_schema(self):
|
||||
return {}
|
||||
|
||||
@lazyproperty
|
||||
def requires_auth(self):
|
||||
if self.view_class:
|
||||
return self.check_view_class_requires_auth()
|
||||
else:
|
||||
return self.check_view_func_requires_auth()
|
||||
|
||||
def check_view_class_requires_auth(self):
|
||||
if issubclass(self.view_class, LoginRequiredMixin):
|
||||
return True
|
||||
|
||||
permission_classes = getattr(self.view_class, 'permission_classes', [])
|
||||
if not permission_classes:
|
||||
return False
|
||||
|
||||
return self.check_permission_classes_requires_auth(permission_classes)
|
||||
|
||||
def check_permission_classes_requires_auth(self, permission_classes, operator=AND):
|
||||
if operator == AND:
|
||||
for pc in permission_classes:
|
||||
if self.check_permission_class_requires_auth(pc):
|
||||
return True
|
||||
return False
|
||||
elif operator == OR:
|
||||
for pc in permission_classes:
|
||||
if not self.check_permission_class_requires_auth(pc):
|
||||
return False
|
||||
return True
|
||||
elif operator == NOT:
|
||||
raise ValueError('NOT operator is not supported in permission_classes')
|
||||
else:
|
||||
return False
|
||||
|
||||
def check_permission_class_requires_auth(self, permission_class):
|
||||
if isinstance(permission_class, OperandHolder):
|
||||
operator = permission_class.operator_class
|
||||
op1_class = permission_class.op1_class
|
||||
op2_class = permission_class.op2_class
|
||||
permission_classes = [op1_class, op2_class]
|
||||
return self.check_permission_classes_requires_auth(permission_classes, operator)
|
||||
else:
|
||||
if issubclass(permission_class, (IsAuthenticated, RBACPermission)):
|
||||
return True
|
||||
if issubclass(permission_class, (AllowAny, )):
|
||||
return False
|
||||
if hasattr(permission_class, '__name__'):
|
||||
if 'Authenticated' in permission_class.__name__:
|
||||
return True
|
||||
if permission_class.__name__.startswith('UserConfirmation'):
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_view_func_requires_auth(self):
|
||||
if hasattr(self.view_func, '__wrapped__'):
|
||||
if hasattr(self.view_func, '__name__'):
|
||||
if 'login_required' in str(self.view_func):
|
||||
return True
|
||||
return False
|
||||
96
utils/generate_view_schema/src/view_schema_generator.py
Normal file
96
utils/generate_view_schema/src/view_schema_generator.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import os
|
||||
import json
|
||||
from django.urls.resolvers import URLPattern, URLResolver
|
||||
from django.urls import get_resolver
|
||||
|
||||
from .url_pattern import CustomURLPattern
|
||||
from .const import log
|
||||
|
||||
|
||||
__all__ = ['ViewSchemaGenerator']
|
||||
|
||||
|
||||
class ViewSchemaGenerator:
|
||||
|
||||
def __init__(self, output_file_dir):
|
||||
os.makedirs(output_file_dir, exist_ok=True)
|
||||
self.output_file_dir = output_file_dir
|
||||
self.resolver = get_resolver()
|
||||
self.url_patterns = self.get_url_patterns()
|
||||
|
||||
def get_url_patterns(self):
|
||||
return self._extract_url_patterns(self.resolver.url_patterns)
|
||||
|
||||
def _extract_url_patterns(self, url_patterns, prefix='/'):
|
||||
url_pattern_objects = []
|
||||
for pattern in url_patterns:
|
||||
if isinstance(pattern, URLResolver):
|
||||
resolver = pattern
|
||||
_prefix = f'{prefix}{resolver.pattern}'
|
||||
patterns = self._extract_url_patterns(resolver.url_patterns, prefix=_prefix)
|
||||
url_pattern_objects.extend(patterns)
|
||||
continue
|
||||
elif isinstance(pattern, URLPattern):
|
||||
p = CustomURLPattern(raw=pattern, prefix=prefix)
|
||||
url_pattern_objects.append(p)
|
||||
else:
|
||||
log(f'Unknown pattern type: {type(pattern)}')
|
||||
return url_pattern_objects
|
||||
|
||||
def generate(self):
|
||||
self.write_url_patterns()
|
||||
self.write_webui_schema()
|
||||
|
||||
def write_webui_schema(self):
|
||||
data = {
|
||||
'GET': {},
|
||||
'POST': {}
|
||||
}
|
||||
post_schema = {}
|
||||
for pattern in self.url_patterns:
|
||||
if pattern.view.requires_auth:
|
||||
continue
|
||||
url = pattern.full_path
|
||||
item = {
|
||||
'allowIf': 'prelogin',
|
||||
'query': {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"additionalProperties": False
|
||||
|
||||
},
|
||||
'body': {
|
||||
'type': 'object',
|
||||
'properties': pattern.view.write_fields_schema,
|
||||
'required': pattern.view.required_fields,
|
||||
'additionalProperties': False
|
||||
}
|
||||
}
|
||||
post_schema[url] = item
|
||||
data['POST'] = post_schema
|
||||
self.write_to_file(data, 'webui_schema.json')
|
||||
|
||||
def write_url_patterns(self):
|
||||
data = []
|
||||
for pattern in self.url_patterns:
|
||||
if pattern.view.requires_auth:
|
||||
continue
|
||||
if pattern.view.view_type == 'function':
|
||||
continue
|
||||
item = {
|
||||
'url': pattern.full_path,
|
||||
'view_path': pattern.view.view_path,
|
||||
'view_requires_auth': pattern.view.requires_auth,
|
||||
'view_type': pattern.view.view_type,
|
||||
}
|
||||
view_write_fields_schema = pattern.view.write_fields_schema
|
||||
if view_write_fields_schema:
|
||||
item['view_write_fields_schema'] = view_write_fields_schema
|
||||
data.append(item)
|
||||
self.write_to_file(data, 'all_url_patterns.json')
|
||||
|
||||
def write_to_file(self, data, filename):
|
||||
file_path = os.path.join(self.output_file_dir, filename)
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
@@ -1,537 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from turtle import up
|
||||
|
||||
import django
|
||||
from django.urls import get_resolver
|
||||
from django.urls.resolvers import URLPattern, URLResolver
|
||||
|
||||
# 获取项目根目录(jumpserver 目录)
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
APP_DIR = os.path.join(BASE_DIR, 'apps')
|
||||
|
||||
# 不改变工作目录,直接加入 sys.path
|
||||
sys.path.insert(0, APP_DIR)
|
||||
sys.path.insert(0, BASE_DIR)
|
||||
|
||||
# 设置 Django 环境
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
|
||||
django.setup()
|
||||
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
OUTPUT_FILE_DIR = os.path.join(CURRENT_DIR, 'output')
|
||||
os.makedirs(OUTPUT_FILE_DIR, exist_ok=True)
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.permissions import OperandHolder, AND, OR, NOT
|
||||
from django.views.generic.base import View
|
||||
from django.views.generic.edit import FormView
|
||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||
from django.core.handlers.asgi import ASGIRequest
|
||||
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rbac.permissions import RBACPermission
|
||||
from common.utils import lazyproperty
|
||||
# 尝试使用模拟请求
|
||||
scope = {
|
||||
'type': 'http',
|
||||
'method': 'GET',
|
||||
'path': '/',
|
||||
'query_string': b'',
|
||||
'headers': [],
|
||||
}
|
||||
|
||||
async def receive():
|
||||
return {'type': 'http.request', 'body': b''}
|
||||
|
||||
request = ASGIRequest(scope, receive)
|
||||
|
||||
|
||||
def log(message):
|
||||
print(message)
|
||||
|
||||
|
||||
class FieldsGenerator:
|
||||
|
||||
def write_fields_schema(self):
|
||||
return {}
|
||||
|
||||
def required_fields(self):
|
||||
return []
|
||||
|
||||
|
||||
class FormFieldsGenerator(FieldsGenerator):
|
||||
def __init__(self, raw_class, view):
|
||||
self.raw_class = raw_class
|
||||
self.view = view
|
||||
|
||||
|
||||
@lazyproperty
|
||||
def fields(self):
|
||||
return self.raw_class().fields
|
||||
|
||||
def write_fields_schema(self):
|
||||
schema = self.get_fields_schema(self.fields)
|
||||
return schema
|
||||
|
||||
def required_fields(self):
|
||||
fields = [name for name, field in self.fields.items() if field.required]
|
||||
return fields
|
||||
|
||||
def get_fields_schema(self, fields):
|
||||
schemas = {}
|
||||
for name, field in fields.items():
|
||||
schema = {
|
||||
'type': self.get_field_type(field),
|
||||
}
|
||||
description = getattr(field, 'help_text', '')
|
||||
if description:
|
||||
schema['description'] = str(description)
|
||||
schemas[name] = schema
|
||||
return schemas
|
||||
|
||||
def get_field_type(self, field):
|
||||
"""将 Django Form Field 类型映射到 JSON Schema 类型"""
|
||||
from django import forms
|
||||
|
||||
type_mapping = {
|
||||
forms.CharField: 'string',
|
||||
forms.EmailField: 'string',
|
||||
forms.URLField: 'string',
|
||||
forms.SlugField: 'string',
|
||||
forms.UUIDField: 'string',
|
||||
forms.RegexField: 'string',
|
||||
forms.FileField: 'string',
|
||||
forms.ImageField: 'string',
|
||||
forms.FilePathField: 'string',
|
||||
forms.GenericIPAddressField: 'string',
|
||||
forms.IntegerField: 'integer',
|
||||
forms.FloatField: 'number',
|
||||
forms.DecimalField: 'number',
|
||||
forms.BooleanField: 'boolean',
|
||||
forms.NullBooleanField: 'boolean',
|
||||
forms.DateField: 'string',
|
||||
forms.TimeField: 'string',
|
||||
forms.DateTimeField: 'string',
|
||||
forms.DurationField: 'string',
|
||||
forms.MultipleChoiceField: 'array',
|
||||
forms.TypedMultipleChoiceField: 'array',
|
||||
forms.ModelMultipleChoiceField: 'array',
|
||||
forms.ChoiceField: 'string',
|
||||
forms.TypedChoiceField: 'string',
|
||||
forms.ModelChoiceField: 'string',
|
||||
forms.JSONField: 'object',
|
||||
}
|
||||
|
||||
for field_type, json_type in type_mapping.items():
|
||||
if isinstance(field, field_type):
|
||||
return json_type
|
||||
|
||||
return 'string'
|
||||
|
||||
|
||||
class SerializerFieldsGenerator(FieldsGenerator):
|
||||
|
||||
def __init__(self, raw_class, view):
|
||||
self.raw_class = raw_class
|
||||
self.view = view
|
||||
|
||||
@lazyproperty
|
||||
def fields(self):
|
||||
fields = {}
|
||||
try:
|
||||
fields = self.raw_class().fields
|
||||
except Exception as e:
|
||||
if hasattr(self.raw_class, '_declared_fields'):
|
||||
fields = self.raw_class._declared_fields
|
||||
return fields
|
||||
|
||||
@lazyproperty
|
||||
def write_fields(self):
|
||||
fields = {}
|
||||
for name, field in self.fields.items():
|
||||
if field.read_only:
|
||||
continue
|
||||
fields[name] = field
|
||||
return fields
|
||||
|
||||
def get_fields_schema(self, fields):
|
||||
schemas = {}
|
||||
if not hasattr(fields, 'items'):
|
||||
return {}
|
||||
for name, field in fields.items():
|
||||
schema = self.get_field_schema(field)
|
||||
if hasattr(field, 'child'):
|
||||
_fields = field.child
|
||||
_fields_schema = self.get_fields_schema(_fields)
|
||||
schema['properties'] = _fields_schema
|
||||
schemas[name] = schema
|
||||
return schemas
|
||||
|
||||
def get_field_schema(self, field):
|
||||
schema = {
|
||||
'type': self.get_field_type(field),
|
||||
}
|
||||
if description := getattr(field, 'help_text', ''):
|
||||
schema['description'] = str(description)
|
||||
extra_schema = self.get_field_extra_schema(field)
|
||||
schema.update(extra_schema)
|
||||
return schema
|
||||
|
||||
def field_is_nested(self, field):
|
||||
from rest_framework import serializers
|
||||
nested_field_types = (
|
||||
serializers.Serializer,
|
||||
serializers.ModelSerializer,
|
||||
)
|
||||
return isinstance(field, nested_field_types)
|
||||
|
||||
def get_field_extra_schema(self, field):
|
||||
"""获取字段的正则表达式模式"""
|
||||
patterns = {}
|
||||
# 检查 validators 中是否有正则验证器
|
||||
if hasattr(field, 'validators'):
|
||||
for validator in field.validators:
|
||||
if hasattr(validator, 'regex'):
|
||||
patterns['pattern'] = str(validator.regex.pattern)
|
||||
break
|
||||
# 针对特定字段类型添加模式
|
||||
from rest_framework import serializers
|
||||
if isinstance(field, serializers.EmailField):
|
||||
patterns['format'] = 'email'
|
||||
elif isinstance(field, serializers.URLField):
|
||||
patterns['format'] = 'uri'
|
||||
elif isinstance(field, (serializers.DateTimeField, serializers.DateField, serializers.TimeField)):
|
||||
patterns['format'] = 'date-time'
|
||||
|
||||
# 添加长度限制
|
||||
if hasattr(field, 'max_length') and field.max_length:
|
||||
patterns['maxLength'] = field.max_length
|
||||
if hasattr(field, 'min_length') and field.min_length:
|
||||
patterns['minLength'] = field.min_length
|
||||
|
||||
# 添加数值范围
|
||||
if hasattr(field, 'max_value') and field.max_value is not None:
|
||||
patterns['maximum'] = field.max_value
|
||||
if hasattr(field, 'min_value') and field.min_value is not None:
|
||||
patterns['minimum'] = field.min_value
|
||||
|
||||
if choices := self.get_field_choices(field):
|
||||
patterns['enum'] = choices
|
||||
return patterns
|
||||
|
||||
def get_field_choices(self, field):
|
||||
from rest_framework import serializers
|
||||
choices = []
|
||||
field_need_query_db = isinstance(field, (
|
||||
serializers.PrimaryKeyRelatedField, # 会查询数据库
|
||||
serializers.StringRelatedField, # 会查询数据库
|
||||
serializers.SlugRelatedField, # 会查询数据库
|
||||
serializers.HyperlinkedRelatedField, # 会查询数据库
|
||||
serializers.HyperlinkedIdentityField,# 会查询数据库
|
||||
serializers.RelatedField, # 基类,会查询数据库
|
||||
serializers.ManyRelatedField, # 会查询数据库
|
||||
serializers.ListSerializer # 可能会查询数据库
|
||||
))
|
||||
if field_need_query_db:
|
||||
return choices # 不返回需要查询数据库的字段选项
|
||||
|
||||
choices = getattr(field, 'choices', [])
|
||||
if not choices:
|
||||
return choices
|
||||
|
||||
if isinstance(choices, dict):
|
||||
# choices 可能是字典、列表或元组
|
||||
choices = list(choices.keys())
|
||||
elif isinstance(choices, (list, tuple)):
|
||||
# choices 可能是 [(value, label), ...] 或 [value, ...]
|
||||
for choice in choices:
|
||||
if isinstance(choice, (list, tuple)) and len(choice) == 2:
|
||||
choices.append(choice[0])
|
||||
else:
|
||||
choices.append(choice)
|
||||
return choices
|
||||
|
||||
def get_field_type(self, field):
|
||||
"""将 Python 字段类型映射到 JSON Schema 类型"""
|
||||
from rest_framework import serializers
|
||||
|
||||
if self.field_is_nested(field):
|
||||
return 'object'
|
||||
|
||||
type_mapping = {
|
||||
serializers.CharField: 'string',
|
||||
serializers.EmailField: 'string',
|
||||
serializers.URLField: 'string',
|
||||
serializers.UUIDField: 'string',
|
||||
serializers.SlugField: 'string',
|
||||
serializers.ChoiceField: 'string',
|
||||
serializers.IntegerField: 'integer',
|
||||
serializers.FloatField: 'number',
|
||||
serializers.DecimalField: 'number',
|
||||
serializers.BooleanField: 'boolean',
|
||||
serializers.DateTimeField: 'string',
|
||||
serializers.DateField: 'string',
|
||||
serializers.TimeField: 'string',
|
||||
serializers.ListField: 'array',
|
||||
serializers.DictField: 'object',
|
||||
serializers.JSONField: 'object',
|
||||
}
|
||||
|
||||
for field_type, json_type in type_mapping.items():
|
||||
if isinstance(field, field_type):
|
||||
return json_type
|
||||
|
||||
return 'string' # 默认类型
|
||||
|
||||
def write_fields_schema(self):
|
||||
fields = self.write_fields
|
||||
if not fields:
|
||||
return {}
|
||||
schema = self.get_fields_schema(fields)
|
||||
return schema
|
||||
|
||||
def required_fields(self):
|
||||
required = []
|
||||
for name, field in self.write_fields.items():
|
||||
if field.required:
|
||||
required.append(name)
|
||||
return required
|
||||
|
||||
|
||||
class CustomView:
|
||||
|
||||
def __init__(self, view_func):
|
||||
self.view_func = view_func
|
||||
|
||||
@property
|
||||
def view_class(self):
|
||||
cls = getattr(self.view_func, 'view_class', None)
|
||||
if not cls:
|
||||
cls = getattr(self.view_func, 'cls', None)
|
||||
return cls
|
||||
|
||||
@property
|
||||
def view_path(self):
|
||||
if self.view_class:
|
||||
v = self.view_class
|
||||
else:
|
||||
v = self.view_func
|
||||
return f'{v.__module__}.{v.__name__}'
|
||||
|
||||
@property
|
||||
def view_type(self):
|
||||
if self.view_class:
|
||||
return 'class'
|
||||
else:
|
||||
return 'function'
|
||||
|
||||
@lazyproperty
|
||||
def fields_generator(self):
|
||||
generator = None
|
||||
if self.view_class:
|
||||
if issubclass(self.view_class, FormView):
|
||||
generator = self.get_form_fields_generator()
|
||||
if issubclass(self.view_class, APIView):
|
||||
generator= self.get_serializer_fields_generator()
|
||||
if not generator:
|
||||
generator = FieldsGenerator()
|
||||
return generator
|
||||
|
||||
@property
|
||||
def write_fields_schema(self):
|
||||
return self.fields_generator.write_fields_schema()
|
||||
|
||||
@property
|
||||
def required_fields(self):
|
||||
return self.fields_generator.required_fields()
|
||||
|
||||
def get_form_fields_generator(self):
|
||||
if hasattr(self.view_class, 'get_comprehensive_form_class'):
|
||||
view_instance = self.view_class(request=request)
|
||||
form_class = view_instance.get_comprehensive_form_class()
|
||||
else:
|
||||
form_class = getattr(self.view_class, 'form_class', None)
|
||||
if not form_class:
|
||||
if hasattr(self.view_class, 'get_form_class'):
|
||||
# TODO: 实例化 view 类需要传入 request 参数
|
||||
view_instance = self.view_class(request=request)
|
||||
form_class = view_instance.get_form_class()
|
||||
if form_class:
|
||||
return FormFieldsGenerator(raw_class=form_class, view=self)
|
||||
|
||||
def get_serializer_fields_generator(self):
|
||||
serializer_class = getattr(self.view_class, 'serializer_class', None)
|
||||
if not serializer_class:
|
||||
if hasattr(self.view_class, 'get_serializer_class'):
|
||||
# TODO: 实例化 view 类需要传入 request 参数
|
||||
view_instance = self.view_class(request=request)
|
||||
serializer_class = view_instance.get_serializer_class()
|
||||
if serializer_class:
|
||||
return SerializerFieldsGenerator(raw_class=serializer_class, view=self)
|
||||
|
||||
@property
|
||||
def query_fields_schema(self):
|
||||
return {}
|
||||
|
||||
@lazyproperty
|
||||
def requires_auth(self):
|
||||
if self.view_class:
|
||||
return self.check_view_class_requires_auth()
|
||||
else:
|
||||
return self.check_view_func_requires_auth()
|
||||
|
||||
def check_view_class_requires_auth(self):
|
||||
if issubclass(self.view_class, LoginRequiredMixin):
|
||||
return True
|
||||
|
||||
permission_classes = getattr(self.view_class, 'permission_classes', [])
|
||||
if not permission_classes:
|
||||
return False
|
||||
|
||||
return self.check_permission_classes_requires_auth(permission_classes)
|
||||
|
||||
def check_permission_classes_requires_auth(self, permission_classes, operator=AND):
|
||||
if operator == AND:
|
||||
for pc in permission_classes:
|
||||
if self.check_permission_class_requires_auth(pc):
|
||||
return True
|
||||
return False
|
||||
elif operator == OR:
|
||||
for pc in permission_classes:
|
||||
if not self.check_permission_class_requires_auth(pc):
|
||||
return False
|
||||
return True
|
||||
elif operator == NOT:
|
||||
raise ValueError('NOT operator is not supported in permission_classes')
|
||||
else:
|
||||
return False
|
||||
|
||||
def check_permission_class_requires_auth(self, permission_class):
|
||||
if isinstance(permission_class, OperandHolder):
|
||||
operator = permission_class.operator_class
|
||||
op1_class = permission_class.op1_class
|
||||
op2_class = permission_class.op2_class
|
||||
permission_classes = [op1_class, op2_class]
|
||||
return self.check_permission_classes_requires_auth(permission_classes, operator)
|
||||
else:
|
||||
if issubclass(permission_class, (IsAuthenticated, RBACPermission)):
|
||||
return True
|
||||
if issubclass(permission_class, (AllowAny, )):
|
||||
return False
|
||||
if hasattr(permission_class, '__name__'):
|
||||
if 'Authenticated' in permission_class.__name__:
|
||||
return True
|
||||
if permission_class.__name__.startswith('UserConfirmation'):
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_view_func_requires_auth(self):
|
||||
if hasattr(self.view_func, '__wrapped__'):
|
||||
if hasattr(self.view_func, '__name__'):
|
||||
if 'login_required' in str(self.view_func):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class CustomURLPattern:
|
||||
def __init__(self, raw, prefix='/'):
|
||||
self.raw = raw
|
||||
self.prefix = prefix
|
||||
self.full_path = f'{self.prefix}{self.raw.pattern}'
|
||||
self.view = CustomView(view_func=self.raw.callback)
|
||||
|
||||
def __str__(self):
|
||||
s = f'{self.full_path} -> {self.view.view_path}'
|
||||
return s
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class ViewSchemaGenerator:
|
||||
|
||||
def __init__(self):
|
||||
self.resolver = get_resolver()
|
||||
self.url_patterns = self.get_url_patterns()
|
||||
|
||||
def get_url_patterns(self):
|
||||
return self._extract_url_patterns(self.resolver.url_patterns)
|
||||
|
||||
def _extract_url_patterns(self, url_patterns, prefix='/'):
|
||||
url_pattern_objects = []
|
||||
for pattern in url_patterns:
|
||||
if isinstance(pattern, URLResolver):
|
||||
resolver = pattern
|
||||
_prefix = f'{prefix}{resolver.pattern}'
|
||||
patterns = self._extract_url_patterns(resolver.url_patterns, prefix=_prefix)
|
||||
url_pattern_objects.extend(patterns)
|
||||
continue
|
||||
elif isinstance(pattern, URLPattern):
|
||||
p = CustomURLPattern(raw=pattern, prefix=prefix)
|
||||
url_pattern_objects.append(p)
|
||||
else:
|
||||
log(f'Unknown pattern type: {type(pattern)}')
|
||||
return url_pattern_objects
|
||||
|
||||
def generate(self):
|
||||
self.write_url_patterns()
|
||||
self.write_webui_schema()
|
||||
|
||||
def write_webui_schema(self):
|
||||
data = {
|
||||
'GET': {},
|
||||
'POST': {}
|
||||
}
|
||||
post_schema = {}
|
||||
for pattern in self.url_patterns:
|
||||
if pattern.view.requires_auth:
|
||||
continue
|
||||
url = pattern.full_path
|
||||
item = {
|
||||
'allowIf': 'prelogin',
|
||||
'query': {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"additionalProperties": False
|
||||
|
||||
},
|
||||
'body': {
|
||||
'type': 'object',
|
||||
'properties': pattern.view.write_fields_schema,
|
||||
'required': pattern.view.required_fields,
|
||||
'additionalProperties': False
|
||||
}
|
||||
}
|
||||
post_schema[url] = item
|
||||
data['POST'] = post_schema
|
||||
self.write_to_file(data, 'webui_schema.json')
|
||||
|
||||
def write_url_patterns(self):
|
||||
data = []
|
||||
for pattern in self.url_patterns:
|
||||
if pattern.view.requires_auth:
|
||||
continue
|
||||
if pattern.view.view_type == 'function':
|
||||
continue
|
||||
item = {
|
||||
'url': pattern.full_path,
|
||||
'view_path': pattern.view.view_path,
|
||||
'view_requires_auth': pattern.view.requires_auth,
|
||||
'view_type': pattern.view.view_type,
|
||||
}
|
||||
view_write_fields_schema = pattern.view.write_fields_schema
|
||||
if view_write_fields_schema:
|
||||
item['view_write_fields_schema'] = view_write_fields_schema
|
||||
data.append(item)
|
||||
self.write_to_file(data, 'all_url_patterns.json')
|
||||
|
||||
def write_to_file(self, data, filename):
|
||||
file_path = os.path.join(OUTPUT_FILE_DIR, filename)
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
ViewSchemaGenerator().generate()
|
||||
Reference in New Issue
Block a user