feat: 修改 Endpoint 获取 Manugs DB listen port 的逻辑

This commit is contained in:
Jiangjie.Bai
2022-09-22 15:52:47 +08:00
parent b8ec60dea1
commit 57e12256e7
8 changed files with 48 additions and 67 deletions

View File

@@ -1,27 +1,23 @@
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.core.validators import MinValueValidator, MaxValueValidator
from applications.models import Application
from applications.utils import db_port_manager
from common.db.models import JMSModel
from common.db.fields import PortField
from common.utils.ip import contains_ip
from common.exceptions import JMSException
class Endpoint(JMSModel):
name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
host = models.CharField(max_length=256, blank=True, verbose_name=_('Host'))
# disabled value=0
# value=0 表示 disabled
https_port = PortField(default=443, verbose_name=_('HTTPS Port'))
http_port = PortField(default=80, verbose_name=_('HTTP Port'))
ssh_port = PortField(default=2222, verbose_name=_('SSH Port'))
rdp_port = PortField(default=3389, verbose_name=_('RDP Port'))
# Todo: Delete
mysql_port = PortField(default=33060, verbose_name=_('MySQL Port'))
mariadb_port = PortField(default=33061, verbose_name=_('MariaDB Port'))
postgresql_port = PortField(default=54320, verbose_name=_('PostgreSQL Port'))
redis_port = PortField(default=63790, verbose_name=_('Redis Port'))
oracle_11g_port = PortField(default=15211, verbose_name=_('Oracle 11g Port'))
oracle_12c_port = PortField(default=15212, verbose_name=_('Oracle 12c Port'))
comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
default_id = '00000000-0000-0000-0000-000000000001'
@@ -33,12 +29,18 @@ class Endpoint(JMSModel):
def __str__(self):
return self.name
def get_port(self, protocol):
return getattr(self, f'{protocol}_port', 0)
def get_oracle_port(self, version):
protocol = f'oracle_{version}'
return self.get_port(protocol)
def get_port(self, target_instance, protocol):
if protocol in ['https', 'http', 'ssh', 'rdp']:
port = getattr(self, f'{protocol}_port', 0)
elif isinstance(target_instance, Application) and target_instance.category_db:
port = db_port_manager.get_port_by_db(target_instance)
if port is None:
error = 'No application port is matched, application id: {}' \
''.format(target_instance.id)
raise JMSException(error)
else:
port = 0
return port
def is_default(self):
return str(self.id) == self.default_id
@@ -48,10 +50,10 @@ class Endpoint(JMSModel):
return
return super().delete(using, keep_parents)
def is_valid_for(self, protocol):
def is_valid_for(self, target_instance, protocol):
if self.is_default():
return True
if self.host and self.get_port(protocol) != 0:
if self.host and self.get_port(target_instance, protocol) != 0:
return True
return False
@@ -105,19 +107,19 @@ class EndpointRule(JMSModel):
return f'{self.name}({self.priority})'
@classmethod
def match(cls, target_ip, protocol):
def match(cls, target_instance, target_ip, protocol):
for endpoint_rule in cls.objects.all().prefetch_related('endpoint'):
if not contains_ip(target_ip, endpoint_rule.ip_group):
continue
if not endpoint_rule.endpoint:
continue
if not endpoint_rule.endpoint.is_valid_for(protocol):
if not endpoint_rule.endpoint.is_valid_for(target_instance, protocol):
continue
return endpoint_rule
@classmethod
def match_endpoint(cls, target_ip, protocol, request=None):
endpoint_rule = cls.match(target_ip, protocol)
def match_endpoint(cls, target_instance, target_ip, protocol, request=None):
endpoint_rule = cls.match(target_instance, target_ip, protocol)
if endpoint_rule:
endpoint = endpoint_rule.endpoint
else: