feat: Endpoint 支持 oracle 版本 (#8585)

* feat: Endpoint 支持 oracle 版本

* feat: Endpoint 支持 oracle 版本

* feat: Endpoint 支持 oracle 版本

Co-authored-by: Jiangjie.Bai <bugatti_it@163.com>
This commit is contained in:
fit2bot
2022-07-13 16:29:05 +08:00
committed by GitHub
parent 2abca39597
commit ce2f6fdc84
11 changed files with 248 additions and 126 deletions

View File

@@ -1,6 +1,7 @@
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework import status
from rest_framework.request import Request
from common.drf.api import JMSBulkModelViewSet
from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404
@@ -18,39 +19,42 @@ __all__ = ['EndpointViewSet', 'EndpointRuleViewSet']
class SmartEndpointViewMixin:
get_serializer: callable
request: Request
# View 处理过程中用的属性
target_instance: None
target_protocol: None
@action(methods=['get'], detail=False, permission_classes=[IsValidUser], url_path='smart')
def smart(self, request, *args, **kwargs):
protocol = request.GET.get('protocol')
if not protocol:
self.target_instance = self.get_target_instance()
self.target_protocol = self.get_target_protocol()
if not self.target_protocol:
error = _('Not found protocol query params')
return Response(data={'error': error}, status=status.HTTP_404_NOT_FOUND)
endpoint = self.match_endpoint(request, protocol)
endpoint = self.match_endpoint()
serializer = self.get_serializer(endpoint)
return Response(serializer.data)
def match_endpoint(self, request, protocol):
instance = self.get_target_instance(request)
endpoint = self.match_endpoint_by_label(instance, protocol)
def match_endpoint(self):
endpoint = self.match_endpoint_by_label()
if not endpoint:
endpoint = self.match_endpoint_by_target_ip(request, instance, protocol)
endpoint = self.match_endpoint_by_target_ip()
return endpoint
@staticmethod
def match_endpoint_by_label(instance, protocol):
return Endpoint.match_by_instance_label(instance, protocol)
def match_endpoint_by_label(self):
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
@staticmethod
def match_endpoint_by_target_ip(request, instance, protocol):
def match_endpoint_by_target_ip(self):
# 用来方便测试
target_ip = request.GET.get('target_ip', '')
if not target_ip and callable(getattr(instance, 'get_target_ip', None)):
target_ip = instance.get_target_ip()
endpoint = EndpointRule.match_endpoint(target_ip, protocol, request)
target_ip = self.request.GET.get('target_ip', '')
if not target_ip and callable(getattr(self.target_instance, 'get_target_ip', None)):
target_ip = self.target_instance.get_target_ip()
endpoint = EndpointRule.match_endpoint(target_ip, self.target_protocol, self.request)
return endpoint
@staticmethod
def get_target_instance(request):
def get_target_instance(self):
request = self.request
asset_id = request.GET.get('asset_id')
app_id = request.GET.get('app_id')
session_id = request.GET.get('session_id')
@@ -77,6 +81,14 @@ class SmartEndpointViewMixin:
instance = get_object_or_404(model, pk=pk)
return instance
def get_target_protocol(self):
protocol = None
if isinstance(self.target_instance, Application) and self.target_instance.is_type(Application.APP_TYPE.oracle):
protocol = self.target_instance.get_target_protocol_for_oracle()
if not protocol:
protocol = self.request.GET.get('protocol')
return protocol
class EndpointViewSet(SmartEndpointViewMixin, JMSBulkModelViewSet):
filterset_fields = ('name', 'host')