@@ -1,15 +1,16 @@
import abc
import os
import abc
import json
import time
import base64
import urllib . parse
from django . http import HttpResponse
from django . shortcuts import get_object_or_404
from rest_framework . request import Request
from rest_framework import status
from rest_framework . exceptions import PermissionDenied
from rest_framework . decorators import action
from rest_framework . response import Response
from rest_framework import status
from rest_framework . request import Request
from common . drf . api import JMSModelViewSet
@@ -25,76 +26,12 @@ from ..models import ConnectionToken
__all__ = [ ' ConnectionTokenViewSet ' , ' SuperConnectionTokenViewSet ' ]
# ExtraActionApiMixin
class ConnectionTokenMixin :
class RDPFileClientProtocolURLMixin :
request : Request
@staticmethod
def check_token_valid ( token : ConnectionToken ) :
is_valid , error = token . check_valid ( )
if not is_valid :
raise PermissionDenied ( error )
@abc.abstractmethod
def get_request_resource_user ( self , serializer ) :
raise NotImplementedError
def get_request_resources ( self , serializer ) :
user = self . get_request_resource_user ( serializer )
asset = serializer . validated_data . get ( ' asset ' )
account = serializer . validated_data . get ( ' account ' )
return user , asset , account
@staticmethod
def check_user_has_resource_permission ( user , asset , account ) :
from perms . utils . account import PermAccountUtil
if not asset or not user :
error = ' '
raise PermissionDenied ( error )
actions , expire_at = PermAccountUtil ( ) . validate_permission (
user , asset , account_username = account
)
if not actions :
error = ' '
raise PermissionDenied ( error )
if expire_at < time . time ( ) :
error = ' '
raise PermissionDenied ( error )
def get_smart_endpoint ( self , protocol , asset = None , application = None ) :
if asset :
target_ip = asset . get_target_ip ( )
elif application :
target_ip = application . get_target_ip ( )
else :
target_ip = ' '
endpoint = EndpointRule . match_endpoint ( target_ip , protocol , self . request )
return endpoint
@staticmethod
def parse_env_bool ( env_key , env_default , true_value , false_value ) :
return true_value if is_true ( os . getenv ( env_key , env_default ) ) else false_value
def get_client_protocol_data ( self , token : ConnectionToken ) :
protocol = token . protocol
username = token . user . username
rdp_config = ssh_token = ' '
if protocol == ' rdp ' :
filename , rdp_config = self . get_rdp_file_info ( token )
elif protocol == ' ssh ' :
filename , ssh_token = self . get_ssh_token ( token )
else :
raise ValueError ( ' Protocol not support: {} ' . format ( protocol ) )
return {
" filename " : filename ,
" protocol " : protocol ,
" username " : username ,
" token " : ssh_token ,
" config " : rdp_config
}
get_serializer : callable
def get_rdp_file_info ( self , token : ConnectionToken ) :
rdp_options = {
@@ -189,6 +126,29 @@ class ConnectionTokenMixin:
filename = urllib . parse . quote ( filename )
return filename
@staticmethod
def parse_env_bool ( env_key , env_default , true_value , false_value ) :
return true_value if is_true ( os . getenv ( env_key , env_default ) ) else false_value
def get_client_protocol_data ( self , token : ConnectionToken ) :
protocol = token . protocol
username = token . user . username
rdp_config = ssh_token = ' '
if protocol == ' rdp ' :
filename , rdp_config = self . get_rdp_file_info ( token )
elif protocol == ' ssh ' :
filename , ssh_token = self . get_ssh_token ( token )
else :
raise ValueError ( ' Protocol not support: {} ' . format ( protocol ) )
return {
" filename " : filename ,
" protocol " : protocol ,
" username " : username ,
" token " : ssh_token ,
" config " : rdp_config
}
def get_ssh_token ( self , token : ConnectionToken ) :
if token . asset :
name = token . asset . name
@@ -207,8 +167,79 @@ class ConnectionTokenMixin:
token = json . dumps ( data )
return filename , token
def get_smart_endpoint ( self , protocol , asset = None ) :
target_ip = asset . get_target_ip ( ) if asset else ' '
endpoint = EndpointRule . match_endpoint ( target_ip , protocol , self . request )
return endpoint
class ConnectionTokenViewSet ( ConnectionTokenMixin , RootOrgViewMixin , JMSModelViewSet ) :
class ExtraActionApiMixin ( RDPFileClientProtocolURLMixin ) :
request : Request
get_object : callable
get_serializer : callable
perform_create : callable
check_token_permission : callable
create_connection_token : callable
@action ( methods = [ ' POST ' ] , detail = False , url_path = ' secret-info/detail ' )
def get_secret_detail ( self , request , * args , * * kwargs ) :
""" 非常重要的 api, 在逻辑层再判断一下 rbac 权限, 双重保险 """
rbac_perm = ' authentication.view_connectiontokensecret '
if not request . user . has_perm ( rbac_perm ) :
raise PermissionDenied ( ' Not allow to view secret ' )
token_id = request . data . get ( ' token ' ) or ' '
token = get_object_or_404 ( ConnectionToken , pk = token_id )
self . check_token_permission ( token )
serializer = self . get_serializer ( instance = token )
return Response ( serializer . data , status = status . HTTP_200_OK )
@action ( methods = [ ' POST ' , ' GET ' ] , detail = False , url_path = ' rdp/file ' )
def get_rdp_file ( self , request , * args , * * kwargs ) :
token = self . create_connection_token ( )
self . check_token_permission ( token )
filename , content = self . get_rdp_file_info ( token )
filename = ' {} .rdp ' . format ( filename )
response = HttpResponse ( content , content_type = ' application/octet-stream ' )
response [ ' Content-Disposition ' ] = ' attachment; filename*=UTF-8 \' \' %s ' % filename
return response
@action ( methods = [ ' POST ' , ' GET ' ] , detail = False , url_path = ' client-url ' )
def get_client_protocol_url ( self , request , * args , * * kwargs ) :
token = self . create_connection_token ( )
self . check_token_permission ( token )
try :
protocol_data = self . get_client_protocol_data ( token )
except ValueError as e :
return Response ( data = { ' error ' : str ( e ) } , status = status . HTTP_400_BAD_REQUEST )
protocol_data = json . dumps ( protocol_data ) . encode ( )
protocol_data = base64 . b64encode ( protocol_data ) . decode ( )
data = {
' url ' : ' jms:// {} ' . format ( protocol_data )
}
return Response ( data = data )
@action ( methods = [ ' PATCH ' ] , detail = True )
def expire ( self , request , * args , * * kwargs ) :
instance = self . get_object ( )
instance . expire ( )
return Response ( status = status . HTTP_204_NO_CONTENT )
@staticmethod
def check_token_permission ( token : ConnectionToken ) :
is_valid , error = token . check_permission ( )
if not is_valid :
raise PermissionDenied ( error )
def create_connection_token ( self ) :
data = self . request . query_params if self . request . method == ' GET ' else self . request . data
serializer = self . get_serializer ( data = data )
serializer . is_valid ( raise_exception = True )
self . perform_create ( serializer )
token : ConnectionToken = serializer . instance
return token
class ConnectionTokenViewSet ( ExtraActionApiMixin , RootOrgViewMixin , JMSModelViewSet ) :
filterset_fields = (
' user_display ' , ' asset_display '
)
@@ -231,72 +262,29 @@ class ConnectionTokenViewSet(ConnectionTokenMixin, RootOrgViewMixin, JMSModelVie
def get_queryset ( self ) :
return ConnectionToken . objects . filter ( user = self . request . user )
def get_request_resource_ user ( self , serializer ) :
def get_user ( self , serializer ) :
return self . request . user
def get_object ( self ) :
if self . request . user . is_service_account :
# TODO: 组件获取 token 详情,将来放在 Super-connection-token API 中
obj = get_object_or_404 ( ConnectionToken , pk = self . kwargs . get ( ' pk ' ) )
else :
obj = super ( ConnectionTokenViewSet , self ) . get_object ( )
return obj
def create_connection_token ( self ) :
data = self . request . query_params if self . request . method == ' GET ' else self . request . data
serializer = self . get_serializer ( data = data )
serializer . is_valid ( raise_exception = True )
self . perform_create ( serializer )
token : ConnectionToken = serializer . instance
return token
def perform_create ( self , serializer ) :
user , asset , account = self . get_request_resources ( serializer )
self . check_user_has_resource_permission ( user , asset , account )
user = self . get_user ( serializer )
asset = serializer . validated_data . get ( ' asset ' )
account_username = serializer . validated_data . get ( ' account_username ' )
self . validate_asset_permission ( user , asset , account_username )
return super ( ConnectionTokenViewSet , self ) . perform_create ( serializer )
@action ( methods = [ ' POST ' ] , detail = False , url_path = ' secret-info/detail ' )
def get_secret_detail ( self , request , * args , * * kwargs ) :
# 非常重要的 api, 在逻辑层再判断一下, 双重保险
perm_required = ' authentication.view_connectiontokensecret '
if not request . user . has_perm ( perm_required ) :
raise PermissionDenied ( ' Not allow to view secret ' )
token_id = request . data . get ( ' token ' ) or ' '
token = get_object_or_404 ( ConnectionToken , pk = token_id )
self . check_token_valid ( token )
serializer = self . get_serializer ( instance = token )
return Response ( serializer . data , status = status . HTTP_200_OK )
@staticmethod
def validate_asset_permission ( user , asset , account_username ) :
from perms . utils . account import PermAccountUtil
actions , expire_at = PermAccountUtil ( ) . validate_permission ( user , asset , account_username )
if not actions :
error = ' '
raise PermissionDenied ( error )
if expire_at < time . time ( ) :
error = ' '
raise PermissionDenied ( error )
@action ( methods = [ ' POST ' , ' GET ' ] , detail = False , url_path = ' rdp/file ' )
def get_rdp_file ( self , request , * args , * * kwargs ) :
token = self . create_connection_token ( )
self . check_token_valid ( token )
filename , content = self . get_rdp_file_info ( token )
filename = ' {} .rdp ' . format ( filename )
response = HttpResponse ( content , content_type = ' application/octet-stream ' )
response [ ' Content-Disposition ' ] = ' attachment; filename*=UTF-8 \' \' %s ' % filename
return response
@action ( methods = [ ' POST ' , ' GET ' ] , detail = False , url_path = ' client-url ' )
def get_client_protocol_url ( self , request , * args , * * kwargs ) :
token = self . create_connection_token ( )
self . check_token_valid ( token )
try :
protocol_data = self . get_client_protocol_data ( token )
except ValueError as e :
return Response ( data = { ' error ' : str ( e ) } , status = status . HTTP_400_BAD_REQUEST )
protocol_data = json . dumps ( protocol_data ) . encode ( )
protocol_data = base64 . b64encode ( protocol_data ) . decode ( )
data = {
' url ' : ' jms:// {} ' . format ( protocol_data )
}
return Response ( data = data )
@action ( methods = [ ' PATCH ' ] , detail = True )
def expire ( self , request , * args , * * kwargs ) :
instance = self . get_object ( )
instance . expire ( )
return Response ( status = status . HTTP_204_NO_CONTENT )
# SuperConnectionToken
class SuperConnectionTokenViewSet ( ConnectionTokenViewSet ) :
@@ -308,7 +296,10 @@ class SuperConnectionTokenViewSet(ConnectionTokenViewSet):
' renewal ' : ' authentication.add_superconnectiontoken '
}
def get_re quest_resource_user ( self , serializer ) :
def get_queryset ( self ) :
return ConnectionToken . objects . all ( )
def get_user ( self , serializer ) :
return serializer . validated_data . get ( ' user ' )
@action ( methods = [ ' PATCH ' ] , detail = False )