From 3853d0bcc6f3c4a5297d57250d1e545b70ea3533 Mon Sep 17 00:00:00 2001 From: wangruidong <940853815@qq.com> Date: Thu, 18 Jan 2024 18:35:32 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E7=BB=91=E5=AE=9A=E7=9A=84=E7=AB=AF?= =?UTF-8?q?=E7=82=B9Default=E4=B8=8B=E8=BD=BDRDP=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E5=9C=B0=E5=9D=80=E6=98=AF=E7=A9=BA=E7=9A=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/authentication/api/connection_token.py | 2 +- apps/terminal/api/component/endpoint.py | 2 +- apps/terminal/models/component/endpoint.py | 26 +++++++++++++-------- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/apps/authentication/api/connection_token.py b/apps/authentication/api/connection_token.py index 076767946..b5bafac32 100644 --- a/apps/authentication/api/connection_token.py +++ b/apps/authentication/api/connection_token.py @@ -205,7 +205,7 @@ class RDPFileClientProtocolURLMixin: return data def get_smart_endpoint(self, protocol, asset=None): - endpoint = Endpoint.match_by_instance_label(asset, protocol) + endpoint = Endpoint.match_by_instance_label(asset, protocol, self.request) if not endpoint: target_ip = asset.get_target_ip() if asset else '' endpoint = EndpointRule.match_endpoint( diff --git a/apps/terminal/api/component/endpoint.py b/apps/terminal/api/component/endpoint.py index b40aba9aa..9b573fc58 100644 --- a/apps/terminal/api/component/endpoint.py +++ b/apps/terminal/api/component/endpoint.py @@ -42,7 +42,7 @@ class SmartEndpointViewMixin: return endpoint def match_endpoint_by_label(self): - return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol) + return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol, self.request) def match_endpoint_by_target_ip(self): target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数,用来方便测试 diff --git a/apps/terminal/models/component/endpoint.py b/apps/terminal/models/component/endpoint.py index d9d4cfab8..cdb9b0135 100644 --- a/apps/terminal/models/component/endpoint.py +++ b/apps/terminal/models/component/endpoint.py @@ -75,7 +75,20 @@ class Endpoint(JMSBaseModel): return endpoint @classmethod - def match_by_instance_label(cls, instance, protocol): + def handle_endpoint_host(cls, endpoint, request=None): + if not endpoint.host and request: + # 动态添加 current request host + host_port = request.get_host() + # IPv6 + if host_port.startswith('['): + host = host_port.split(']:')[0].rstrip(']') + ']' + else: + host = host_port.split(':')[0] + endpoint.host = host + return endpoint + + @classmethod + def match_by_instance_label(cls, instance, protocol, request=None): from assets.models import Asset from terminal.models import Session if isinstance(instance, Session): @@ -88,6 +101,7 @@ class Endpoint(JMSBaseModel): endpoints = cls.objects.filter(name__in=list(values)).order_by('-date_updated') for endpoint in endpoints: if endpoint.is_valid_for(instance, protocol): + endpoint = cls.handle_endpoint_host(endpoint, request) return endpoint @@ -130,13 +144,5 @@ class EndpointRule(JMSBaseModel): endpoint = endpoint_rule.endpoint else: endpoint = Endpoint.get_or_create_default(request) - if not endpoint.host and request: - # 动态添加 current request host - host_port = request.get_host() - # IPv6 - if host_port.startswith('['): - host = host_port.split(']:')[0].rstrip(']') + ']' - else: - host = host_port.split(':')[0] - endpoint.host = host + endpoint = Endpoint.handle_endpoint_host(endpoint, request) return endpoint