Compare commits

..

15 Commits

Author SHA1 Message Date
ibuler
7c08582e7d Merge branch 'master' into v2.0 2020-07-08 16:14:19 +08:00
ibuler
3555fe1c20 Merge branch 'v2.0' of github.com:jumpserver/jumpserver into v2.0 2020-07-08 16:14:06 +08:00
ibuler
4a1045ba81 ci(build&release): automated build and release 2020-07-08 15:58:57 +08:00
ibuler
7fc3218010 Add example api 2020-07-08 15:58:57 +08:00
ibuler
0dd4c8adc2 Merge branch 'v2.0' of github.com:jumpserver/jumpserver into v2.0 2020-07-08 15:48:05 +08:00
xinwen
a9626c2b39 Merge pull request #4223 from jumpserver/update-add-filter
[Update] assets/gathered_user: add filter fields
2020-07-01 17:32:58 +08:00
xinwen
b265fad50f [Update] assets/gathered_user: add filter fields 2020-07-01 17:21:00 +08:00
BaiJiangJie
f40b50dddd Merge pull request #4209 from jumpserver/ftp_log_order
fix: order FTP logs by start date
2020-07-01 15:02:44 +08:00
xinwen
0d6255b07f Merge pull request #4218 from jumpserver/system-user-add-filter
System user add filter
2020-07-01 15:01:07 +08:00
xinwen
06b02cfcfd [Update] System users: add filter fields 2020-07-01 14:57:52 +08:00
jym503558564
a7e1ed6f03 fix: order FTP logs by start date 2020-07-01 11:04:42 +08:00
xinwen
1c6f89519a Merge pull request #4201 from jumpserver/fix-i18n
Fix i18n
2020-06-30 14:21:44 +08:00
xinwen
24d093747f [Fix] X-Pack / Cloud Management Center i18n 2020-06-30 14:18:21 +08:00
ibuler
7548bb8976 Merge branch 'dev' of github.com:jumpserver/jumpserver into dev 2020-06-17 11:33:57 +08:00
ibuler
8ffa5e0aec Add example api 2020-06-16 20:11:50 +08:00
344 changed files with 3417 additions and 11630 deletions

17
.github/ISSUE_TEMPLATE.md vendored Normal file
View File

@@ -0,0 +1,17 @@
[Briefly describe your issue]
##### Version in use
[Please provide the JumpServer version you are using, e.g. 2.0.1. Note: versions 1.4 and below are no longer supported]
##### Steps to reproduce
1. [Step 1]
2. [Step 2]
##### Observed behavior [screenshots help; capture the whole window if possible]
##### Other
[Note:] Please close the issue once it is resolved

View File

@@ -1,10 +0,0 @@
---
name: Feature request
about: Suggest ideas and improvements for this project
title: "[Feature] "
labels: 类型:需求
assignees: ibuler
---
**Please describe your requirement or suggested improvement.**

View File

@@ -1,22 +0,0 @@
---
name: Bug report
about: Report product defects to help us improve
title: "[Bug] "
labels: 类型:bug
assignees: wojiushixiaobai
---
**JumpServer version (versions below v1.5.9 are no longer supported)**
**Browser version**
**Bug description**
**Steps to reproduce the bug (screenshots preferred)**
1.
2.
3.

View File

@@ -1,10 +0,0 @@
---
name: Question
about: Ask questions about installing, deploying, or using this project
title: "[Question] "
labels: 类型:提问
assignees: wojiushixiaobai
---
**Please describe your question.**

View File

@@ -1,12 +0,0 @@
on: [push, pull_request, release]
name: JumpServer repos generic handler
jobs:
generic_handler:
name: Run generic handler
runs-on: ubuntu-latest
steps:
- uses: jumpserver/action-generic-handler@master
env:
GITHUB_TOKEN: ${{ secrets.PRIVATE_TOKEN }}

View File

@@ -1,34 +1,19 @@
FROM registry.fit2cloud.com/public/python:v3 as stage-build
MAINTAINER Jumpserver Team <ibuler@qq.com>
ARG VERSION
ENV VERSION=$VERSION
WORKDIR /opt/jumpserver
ADD . .
RUN cd utils && bash -ixeu build.sh
FROM registry.fit2cloud.com/public/python:v3
ARG PIP_MIRROR=https://pypi.douban.com/simple
ENV PIP_MIRROR=$PIP_MIRROR
ARG MYSQL_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/mysql/yum/mysql57-community-el6/
ENV MYSQL_MIRROR=$MYSQL_MIRROR
MAINTAINER Jumpserver Team <ibuler@qq.com>
WORKDIR /opt/jumpserver
COPY ./requirements ./requirements
RUN useradd jumpserver
RUN yum -y install epel-release && \
echo -e "[mysql]\nname=mysql\nbaseurl=${MYSQL_MIRROR}\ngpgcheck=0\nenabled=1" > /etc/yum.repos.d/mysql.repo
RUN yum -y install $(cat requirements/rpm_requirements.txt)
RUN pip install --upgrade pip setuptools==49.6.0 wheel -i ${PIP_MIRROR} && \
pip config set global.index-url ${PIP_MIRROR}
RUN pip install $(grep 'jms' requirements/requirements.txt) -i https://pypi.org/simple
RUN pip install -r requirements/requirements.txt
COPY --from=stage-build /opt/jumpserver/release/jumpserver /opt/jumpserver
COPY ./requirements /tmp/requirements
RUN yum -y install epel-release && \
echo -e "[mysql]\nname=mysql\nbaseurl=https://mirrors.tuna.tsinghua.edu.cn/mysql/yum/mysql57-community-el6/\ngpgcheck=0\nenabled=1" > /etc/yum.repos.d/mysql.repo
RUN cd /tmp/requirements && yum -y install $(cat rpm_requirements.txt)
RUN cd /tmp/requirements && pip install --upgrade pip setuptools && pip install wheel && \
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt || pip install -r requirements.txt
RUN mkdir -p /root/.ssh/ && echo -e "Host *\n\tStrictHostKeyChecking no\n\tUserKnownHostsFile /dev/null" > /root/.ssh/config
COPY . /opt/jumpserver
RUN echo > config.yml
VOLUME /opt/jumpserver/data
VOLUME /opt/jumpserver/logs

View File

@@ -4,10 +4,6 @@
[![Django](https://img.shields.io/badge/django-2.2-brightgreen.svg?style=plastic)](https://www.djangoproject.com/)
[![Docker Pulls](https://img.shields.io/docker/pulls/jumpserver/jms_all.svg)](https://hub.docker.com/u/jumpserver)
|Developer Wanted|
|------------------|
|JumpServer is looking for developers. Come and contribute to changing the world, even a little; contact me at <ibuler@fit2cloud.com> |
JumpServer is the world's first open-source bastion host, released under the GNU GPL v2.0 license. It is an operations security audit system that complies with the 4A standard.
JumpServer is developed mainly in Python / Django, follows Web 2.0 conventions, and ships with an industry-leading Web Terminal, offering a clean interface and a good user experience.
@@ -20,13 +16,12 @@ JumpServer adopts a distributed architecture, supporting multi-datacenter cross-region deployment and horizontal
## Highlights
- Open source: zero barrier to entry, quickly obtain and install online;
- Distributed: easily supports large-scale concurrent access;
- Open source: zero barrier to entry, quickly obtain and install online; fix releases are provided as circumstances require
- Distributed: easily supports large-scale concurrent access;
- Plugin-free: only a browser is needed, for a first-class Web Terminal experience;
- Multi-cloud support: one system manages assets across different clouds;
- Cloud storage: audit recordings are stored in the cloud and never lost;
- Multi-tenant: one system shared by multiple subsidiaries and departments
- Multi-application support: databases, Windows remote applications, Kubernetes.
- Multi-tenant: one system shared by multiple subsidiaries and departments
## Release notes
@@ -199,76 +194,13 @@ v2.1.0 is the feature release that follows v2.0.0.
<td>File transfer</td>
<td>Upload and download records can be audited</td>
</tr>
<tr>
<td rowspan="20">Database audit<br>Database</td>
<td rowspan="2">Connection methods</td>
<td>Command line</td>
</tr>
<tr>
<td>Web UI (X-PACK)</td>
</tr>
<tr>
<td rowspan="4">Supported databases</td>
<td>MySQL</td>
</tr>
<tr>
<td>Oracle (X-PACK)</td>
</tr>
<tr>
<td>MariaDB (X-PACK)</td>
</tr>
<tr>
<td>PostgreSQL (X-PACK)</td>
</tr>
<tr>
<td rowspan="6">Feature highlights</td>
<td>Syntax highlighting</td>
</tr>
<tr>
<td>SQL formatting</td>
</tr>
<tr>
<td>Keyboard shortcuts</td>
</tr>
<tr>
<td>Execute selected statements</td>
</tr>
<tr>
<td>SQL query history</td>
</tr>
<tr>
<td>Create DB and TABLE from the page</td>
</tr>
<tr>
<td rowspan="2">Session audit</td>
<td>Command records</td>
</tr>
<tr>
<td>Session replay</td>
</tr>
</table>
## Quick start
- [Quick install](https://docs.jumpserver.org/zh/master/install/setup_by_fast/)
- [Full documentation](https://docs.jumpserver.org)
- [Demo video](https://www.bilibili.com/video/BV1ZV41127GB)
## Component projects
- [Lina](https://github.com/jumpserver/lina) JumpServer Web UI project
- [Luna](https://github.com/jumpserver/luna) JumpServer Web Terminal project
- [Koko](https://github.com/jumpserver/koko) JumpServer character-protocol Connector, replacing the earlier Python implementation [Coco](https://github.com/jumpserver/coco)
- [Guacamole](https://github.com/jumpserver/docker-guacamole) JumpServer graphical-protocol Connector, which depends on [Apache Guacamole](https://guacamole.apache.org/)
## Acknowledgements
- [Apache Guacamole](https://guacamole.apache.org/) Connects RDP, SSH and VNC devices from the web; JumpServer's graphical connections depend on it
- [OmniDB](https://omnidb.org/) Connects to and works with databases from the web; JumpServer's web database feature depends on it
## JumpServer Enterprise Edition
- [Apply for an Enterprise Edition trial](https://jinshuju.net/f/kyOYpi)
> Note: the Enterprise Edition supports offline installation; a high-speed download link is provided once the application is approved.
- [Demo video](https://jumpserver.oss-cn-hangzhou.aliyuncs.com/jms-media/%E3%80%90%E6%BC%94%E7%A4%BA%E8%A7%86%E9%A2%91%E3%80%91Jumpserver%20%E5%A0%A1%E5%9E%92%E6%9C%BA%20V1.5.0%20%E6%BC%94%E7%A4%BA%E8%A7%86%E9%A2%91%20-%20final.mp4)
## Case studies

2
apps/.gitattributes vendored
View File

@@ -1,2 +0,0 @@
*.js linguist-language=python
*.html linguist-language=python

View File

@@ -1,5 +1,2 @@
from .application import *
from .mixin import *
from .remote_app import *
from .database_app import *
from .k8s_app import *

View File

@@ -1,20 +0,0 @@
# coding: utf-8
#
from orgs.mixins.api import OrgBulkModelViewSet
from .mixin import ApplicationAttrsSerializerViewMixin
from ..hands import IsOrgAdminOrAppUser
from .. import models, serializers
__all__ = [
'ApplicationViewSet',
]
class ApplicationViewSet(ApplicationAttrsSerializerViewMixin, OrgBulkModelViewSet):
model = models.Application
filter_fields = ('name', 'type', 'category')
search_fields = filter_fields
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.ApplicationSerializer

View File

@@ -1,20 +0,0 @@
# coding: utf-8
#
from orgs.mixins.api import OrgBulkModelViewSet
from .. import models
from .. import serializers
from ..hands import IsOrgAdminOrAppUser
__all__ = [
'K8sAppViewSet',
]
class K8sAppViewSet(OrgBulkModelViewSet):
model = models.K8sApp
filter_fields = ('name',)
search_fields = filter_fields
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.K8sAppSerializer

View File

@@ -1,95 +0,0 @@
from common.exceptions import JMSException
from .. import models
class ApplicationAttrsSerializerViewMixin:
def get_serializer_class(self):
serializer_class = super().get_serializer_class()
app_type = self.request.query_params.get('type')
app_category = self.request.query_params.get('category')
type_options = list(dict(models.Category.get_all_type_serializer_mapper()).keys())
category_options = list(dict(models.Category.get_category_serializer_mapper()).keys())
# ListAPIView has no `action` attribute
# Don't use the request method attribute, because for an OPTIONS request the method is POST
action = getattr(self, 'action', 'list')
if app_type and app_type not in type_options:
raise JMSException(
'Invalid query parameter `type`, select from the following options: {}'
''.format(type_options)
)
if app_category and app_category not in category_options:
raise JMSException(
'Invalid query parameter `category`, select from the following options: {}'
''.format(category_options)
)
if action in [
'create', 'update', 'partial_update', 'bulk_update', 'partial_bulk_update'
] and not app_type:
# action: create / update
raise JMSException(
'The `{}` action must take the `type` query parameter'.format(action)
)
if app_type:
# action: create / update / list / retrieve / metadata
attrs_cls = models.Category.get_type_serializer_cls(app_type)
elif app_category:
# action: list / retrieve / metadata
attrs_cls = models.Category.get_category_serializer_cls(app_category)
else:
attrs_cls = models.Category.get_no_password_serializer_cls()
return type('ApplicationDynamicSerializer', (serializer_class,), {'attrs': attrs_cls()})
class SerializeApplicationToTreeNodeMixin:
@staticmethod
def _serialize_db(db):
return {
'id': db.id,
'name': db.name,
'title': db.name,
'pId': '',
'open': False,
'iconSkin': 'database',
'meta': {'type': 'database_app'}
}
@staticmethod
def _serialize_remote_app(remote_app):
return {
'id': remote_app.id,
'name': remote_app.name,
'title': remote_app.name,
'pId': '',
'open': False,
'isParent': False,
'iconSkin': 'chrome',
'meta': {'type': 'remote_app'}
}
@staticmethod
def _serialize_cloud(cloud):
return {
'id': cloud.id,
'name': cloud.name,
'title': cloud.name,
'pId': '',
'open': False,
'isParent': False,
'iconSkin': 'k8s',
'meta': {'type': 'k8s_app'}
}
def _serialize(self, application):
method_name = f'_serialize_{application.category}'
data = getattr(self, method_name)(application)
return data
def serialize_applications(self, applications):
data = [self._serialize(application) for application in applications]
return data
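To make the dynamic-serializer trick in ApplicationAttrsSerializerViewMixin above easier to follow, here is a minimal, framework-free sketch of the same type() composition; ApplicationSerializer and MySQLAttrsSerializer below are bare stand-ins for the real DRF serializer classes, not the actual implementations.

class ApplicationSerializer:
    # stand-in for the DRF base serializer returned by super().get_serializer_class()
    attrs = None

class MySQLAttrsSerializer:
    # stand-in for database_app.MySQLAttrsSerializer
    default_port = 3306

TYPE_SERIALIZER_MAPPER = {'mysql': MySQLAttrsSerializer}

def build_serializer_cls(app_type):
    # Same pattern as get_serializer_class(): pick the attrs serializer for the
    # requested type and graft it onto a throwaway subclass with type().
    attrs_cls = TYPE_SERIALIZER_MAPPER[app_type]
    return type('ApplicationDynamicSerializer', (ApplicationSerializer,), {'attrs': attrs_cls()})

dynamic_cls = build_serializer_cls('mysql')
assert isinstance(dynamic_cls.attrs, MySQLAttrsSerializer)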

View File

@@ -3,9 +3,8 @@
from orgs.mixins.api import OrgBulkModelViewSet
from orgs.mixins import generics
from common.exceptions import JMSException
from ..hands import IsOrgAdmin, IsAppUser
from .. import models
from ..models import RemoteApp
from ..serializers import RemoteAppSerializer, RemoteAppConnectionInfoSerializer
@@ -15,7 +14,7 @@ __all__ = [
class RemoteAppViewSet(OrgBulkModelViewSet):
model = models.RemoteApp
model = RemoteApp
filter_fields = ('name', 'type', 'comment')
search_fields = filter_fields
permission_classes = (IsOrgAdmin,)
@@ -23,18 +22,6 @@ class RemoteAppViewSet(OrgBulkModelViewSet):
class RemoteAppConnectionInfoApi(generics.RetrieveAPIView):
model = models.Application
model = RemoteApp
permission_classes = (IsAppUser, )
serializer_class = RemoteAppConnectionInfoSerializer
@staticmethod
def check_category_allowed(obj):
if not obj.category_is_remote_app:
raise JMSException(
'The request instance(`{}`) is not of category `remote_app`'.format(obj.category)
)
def get_object(self):
obj = super().get_object()
self.check_category_allowed(obj)
return obj

View File

@@ -23,7 +23,6 @@ REMOTE_APP_TYPE_CHROME_FIELDS = [
REMOTE_APP_TYPE_MYSQL_WORKBENCH_FIELDS = [
{'name': 'mysql_workbench_ip'},
{'name': 'mysql_workbench_name'},
{'name': 'mysql_workbench_port'},
{'name': 'mysql_workbench_username'},
{'name': 'mysql_workbench_password', 'write_only': True}
]

View File

@@ -1,34 +0,0 @@
# Generated by Django 2.2.13 on 2020-08-07 07:13
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
('applications', '0004_auto_20191218_1705'),
]
operations = [
migrations.CreateModel(
name='K8sApp',
fields=[
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('org_id', models.CharField(blank=True, db_index=True, default='', max_length=36, verbose_name='Organization')),
('name', models.CharField(max_length=128, verbose_name='Name')),
('type', models.CharField(choices=[('k8s', 'Kubernetes')], default='k8s', max_length=128, verbose_name='Type')),
('cluster', models.CharField(max_length=1024, verbose_name='Cluster')),
('comment', models.TextField(blank=True, default='', max_length=128, verbose_name='Comment')),
],
options={
'verbose_name': 'KubernetesApp',
'ordering': ('name',),
'unique_together': {('org_id', 'name')},
},
),
]

View File

@@ -1,140 +0,0 @@
# Generated by Django 2.2.13 on 2020-10-19 12:01
from django.db import migrations, models
import django.db.models.deletion
import django_mysql.models
import uuid
CATEGORY_DB_LIST = ['mysql', 'oracle', 'postgresql', 'mariadb']
CATEGORY_REMOTE_LIST = ['chrome', 'mysql_workbench', 'vmware_client', 'custom']
CATEGORY_CLOUD_LIST = ['k8s']
CATEGORY_DB = 'db'
CATEGORY_REMOTE = 'remote_app'
CATEGORY_CLOUD = 'cloud'
CATEGORY_LIST = [CATEGORY_DB, CATEGORY_REMOTE, CATEGORY_CLOUD]
def get_application_category(old_app):
_type = old_app.type
if _type in CATEGORY_DB_LIST:
category = CATEGORY_DB
elif _type in CATEGORY_REMOTE_LIST:
category = CATEGORY_REMOTE
elif _type in CATEGORY_CLOUD_LIST:
category = CATEGORY_CLOUD
else:
category = None
return category
def common_to_application_json(old_app):
category = get_application_category(old_app)
date_updated = old_app.date_updated if hasattr(old_app, 'date_updated') else old_app.date_created
return {
'id': old_app.id,
'name': old_app.name,
'type': old_app.type,
'category': category,
'comment': old_app.comment,
'created_by': old_app.created_by,
'date_created': old_app.date_created,
'date_updated': date_updated,
'org_id': old_app.org_id
}
def db_to_application_json(database):
app_json = common_to_application_json(database)
app_json.update({
'attrs': {
'host': database.host,
'port': database.port,
'database': database.database
}
})
return app_json
def remote_to_application_json(remote):
app_json = common_to_application_json(remote)
attrs = {
'asset': str(remote.asset.id),
'path': remote.path,
}
attrs.update(remote.params)
app_json.update({
'attrs': attrs
})
return app_json
def k8s_to_application_json(k8s):
app_json = common_to_application_json(k8s)
app_json.update({
'attrs': {
'cluster': k8s.cluster
}
})
return app_json
def migrate_and_integrate_applications(apps, schema_editor):
db_alias = schema_editor.connection.alias
database_app_model = apps.get_model("applications", "DatabaseApp")
remote_app_model = apps.get_model("applications", "RemoteApp")
k8s_app_model = apps.get_model("applications", "K8sApp")
database_apps = database_app_model.objects.using(db_alias).all()
remote_apps = remote_app_model.objects.using(db_alias).all()
k8s_apps = k8s_app_model.objects.using(db_alias).all()
database_applications = [db_to_application_json(db_app) for db_app in database_apps]
remote_applications = [remote_to_application_json(remote_app) for remote_app in remote_apps]
k8s_applications = [k8s_to_application_json(k8s_app) for k8s_app in k8s_apps]
applications_json = database_applications + remote_applications + k8s_applications
application_model = apps.get_model("applications", "Application")
applications = [
application_model(**application_json)
for application_json in applications_json
if application_json['category'] in CATEGORY_LIST
]
for application in applications:
if application_model.objects.using(db_alias).filter(name=application.name).exists():
application.name = '{}-{}'.format(application.name, application.type)
application.save()
class Migration(migrations.Migration):
dependencies = [
('assets', '0057_fill_node_value_assets_amount_and_parent_key'),
('applications', '0005_k8sapp'),
]
operations = [
migrations.CreateModel(
name='Application',
fields=[
('org_id', models.CharField(blank=True, db_index=True, default='', max_length=36, verbose_name='Organization')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('name', models.CharField(max_length=128, verbose_name='Name')),
('category', models.CharField(choices=[('db', 'Database'), ('remote_app', 'Remote app'), ('cloud', 'Cloud')], max_length=16, verbose_name='Category')),
('type', models.CharField(choices=[('mysql', 'MySQL'), ('oracle', 'Oracle'), ('postgresql', 'PostgreSQL'), ('mariadb', 'MariaDB'), ('chrome', 'Chrome'), ('mysql_workbench', 'MySQL Workbench'), ('vmware_client', 'vSphere Client'), ('custom', 'Custom'), ('k8s', 'Kubernetes')], max_length=16, verbose_name='Type')),
('attrs', django_mysql.models.JSONField(default=dict)),
('comment', models.TextField(blank=True, default='', max_length=128, verbose_name='Comment')),
('domain', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='applications', to='assets.Domain', verbose_name='Domain')),
],
options={
'ordering': ('name',),
'unique_together': {('org_id', 'name')},
},
),
migrations.RunPython(migrate_and_integrate_applications),
]
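As a rough, self-contained illustration of what the conversion helpers in this migration produce (all field values below are invented), one legacy DatabaseApp row becomes a single Application JSON dict whose connection details move under attrs:

from types import SimpleNamespace

old_db_app = SimpleNamespace(
    id='uuid-example', name='prod-mysql', type='mysql', comment='',
    created_by='admin', date_created='2020-10-19', org_id='',
    host='10.0.0.5', port=3306, database='jumpserver',
)

# 'mysql' is in CATEGORY_DB_LIST, so the category resolves to 'db'
application_json = {
    'id': old_db_app.id, 'name': old_db_app.name, 'type': old_db_app.type,
    'category': 'db', 'comment': old_db_app.comment,
    'created_by': old_db_app.created_by, 'date_created': old_db_app.date_created,
    'org_id': old_db_app.org_id,
    'attrs': {'host': old_db_app.host, 'port': old_db_app.port, 'database': old_db_app.database},
}
# If another Application already uses the same name, migrate_and_integrate_applications
# appends the type, e.g. 'prod-mysql' becomes 'prod-mysql-mysql'.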

View File

@@ -1,4 +1,2 @@
from .application import *
from .remote_app import *
from .database_app import *
from .k8s_app import *

View File

@@ -1,141 +0,0 @@
from itertools import chain
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django_mysql.models import JSONField, QuerySet
from orgs.mixins.models import OrgModelMixin
from common.mixins import CommonModelMixin
from common.db.models import ChoiceSet
class DBType(ChoiceSet):
mysql = 'mysql', 'MySQL'
oracle = 'oracle', 'Oracle'
pgsql = 'postgresql', 'PostgreSQL'
mariadb = 'mariadb', 'MariaDB'
@classmethod
def get_type_serializer_cls_mapper(cls):
from ..serializers import database_app
mapper = {
cls.mysql: database_app.MySQLAttrsSerializer,
cls.oracle: database_app.OracleAttrsSerializer,
cls.pgsql: database_app.PostgreAttrsSerializer,
cls.mariadb: database_app.MariaDBAttrsSerializer,
}
return mapper
class RemoteAppType(ChoiceSet):
chrome = 'chrome', 'Chrome'
mysql_workbench = 'mysql_workbench', 'MySQL Workbench'
vmware_client = 'vmware_client', 'vSphere Client'
custom = 'custom', _('Custom')
@classmethod
def get_type_serializer_cls_mapper(cls):
from ..serializers import remote_app
mapper = {
cls.chrome: remote_app.ChromeAttrsSerializer,
cls.mysql_workbench: remote_app.MySQLWorkbenchAttrsSerializer,
cls.vmware_client: remote_app.VMwareClientAttrsSerializer,
cls.custom: remote_app.CustomRemoteAppAttrsSeralizers,
}
return mapper
class CloudType(ChoiceSet):
k8s = 'k8s', 'Kubernetes'
@classmethod
def get_type_serializer_cls_mapper(cls):
from ..serializers import k8s_app
mapper = {
cls.k8s: k8s_app.K8sAttrsSerializer,
}
return mapper
class Category(ChoiceSet):
db = 'db', _('Database')
remote_app = 'remote_app', _('Remote app')
cloud = 'cloud', 'Cloud'
@classmethod
def get_category_type_mapper(cls):
return {
cls.db: DBType,
cls.remote_app: RemoteAppType,
cls.cloud: CloudType
}
@classmethod
def get_category_type_choices_mapper(cls):
return {
name: tp.choices
for name, tp in cls.get_category_type_mapper().items()
}
@classmethod
def get_type_choices(cls, category):
return cls.get_category_type_choices_mapper().get(category, [])
@classmethod
def get_all_type_choices(cls):
all_grouped_choices = tuple(cls.get_category_type_choices_mapper().values())
return tuple(chain(*all_grouped_choices))
@classmethod
def get_all_type_serializer_mapper(cls):
mapper = {}
for tp in cls.get_category_type_mapper().values():
mapper.update(tp.get_type_serializer_cls_mapper())
return mapper
@classmethod
def get_type_serializer_cls(cls, tp):
mapper = cls.get_all_type_serializer_mapper()
return mapper.get(tp, None)
@classmethod
def get_category_serializer_mapper(cls):
from ..serializers import remote_app, database_app, k8s_app
return {
cls.db: database_app.DBAttrsSerializer,
cls.remote_app: remote_app.RemoteAppAttrsSerializer,
cls.cloud: k8s_app.CloudAttrsSerializer,
}
@classmethod
def get_category_serializer_cls(cls, cg):
mapper = cls.get_category_serializer_mapper()
return mapper.get(cg, None)
@classmethod
def get_no_password_serializer_cls(cls):
from ..serializers import common
return common.NoPasswordSerializer
class Application(CommonModelMixin, OrgModelMixin):
name = models.CharField(max_length=128, verbose_name=_('Name'))
domain = models.ForeignKey('assets.Domain', null=True, blank=True, related_name='applications', verbose_name=_("Domain"), on_delete=models.SET_NULL)
category = models.CharField(max_length=16, choices=Category.choices, verbose_name=_('Category'))
type = models.CharField(max_length=16, choices=Category.get_all_type_choices(), verbose_name=_('Type'))
attrs = JSONField()
comment = models.TextField(
max_length=128, default='', blank=True, verbose_name=_('Comment')
)
class Meta:
unique_together = [('org_id', 'name')]
ordering = ('name',)
def __str__(self):
category_display = self.get_category_display()
type_display = self.get_type_display()
return f'{self.name}({type_display})[{category_display}]'
def category_is_remote_app(self):
return self.category == Category.remote_app
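Read as a lookup chain, the mappers above resolve like this (class and module names as they appear in this file; the example keys are only for illustration):

# Category.get_category_type_mapper()['db']            -> DBType
# DBType.get_type_serializer_cls_mapper()['mysql']     -> database_app.MySQLAttrsSerializer
# Category.get_type_serializer_cls('mysql')            -> database_app.MySQLAttrsSerializer
# Category.get_category_serializer_cls('db')           -> database_app.DBAttrsSerializer
# Category.get_no_password_serializer_cls()            -> common.NoPasswordSerializer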

View File

@@ -1,27 +0,0 @@
from django.utils.translation import gettext_lazy as _
from common.db import models
from orgs.mixins.models import OrgModelMixin
class K8sApp(OrgModelMixin, models.JMSModel):
class TYPE(models.ChoiceSet):
K8S = 'k8s', _('Kubernetes')
name = models.CharField(max_length=128, verbose_name=_('Name'))
type = models.CharField(
default=TYPE.K8S, choices=TYPE.choices,
max_length=128, verbose_name=_('Type')
)
cluster = models.CharField(max_length=1024, verbose_name=_('Cluster'))
comment = models.TextField(
max_length=128, default='', blank=True, verbose_name=_('Comment')
)
def __str__(self):
return self.name
class Meta:
unique_together = [('org_id', 'name'), ]
verbose_name = _('KubernetesApp')
ordering = ('name', )

View File

@@ -1,5 +1,2 @@
from .application import *
from .remote_app import *
from .database_app import *
from .k8s_app import *
from .common import *

View File

@@ -1,44 +0,0 @@
# coding: utf-8
#
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .. import models
__all__ = [
'ApplicationSerializer',
]
class ApplicationSerializer(BulkOrgResourceModelSerializer):
category_display = serializers.ReadOnlyField(source='get_category_display', label=_('Category'))
type_display = serializers.ReadOnlyField(source='get_type_display', label=_('Type'))
class Meta:
model = models.Application
fields = [
'id', 'name', 'category', 'category_display', 'type', 'type_display', 'attrs',
'domain', 'created_by', 'date_created', 'date_updated', 'comment'
]
read_only_fields = [
'created_by', 'date_created', 'date_updated', 'get_type_display',
]
def create(self, validated_data):
attrs = validated_data.pop('attrs', {})
instance = super().create(validated_data)
instance.attrs = attrs
instance.save()
return instance
def update(self, instance, validated_data):
new_attrs = validated_data.pop('attrs', {})
instance = super().update(instance, validated_data)
attrs = instance.attrs
attrs.update(new_attrs)
instance.attrs = attrs
instance.save()
return instance

View File

@@ -1,11 +0,0 @@
from rest_framework import serializers
class NoPasswordSerializer(serializers.JSONField):
def to_representation(self, value):
new_value = {}
for k, v in value.items():
if 'password' not in k:
new_value[k] = v
return new_value
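A standalone restatement of what to_representation above does (the attrs dict is invented): any key containing 'password' is dropped from the representation.

attrs = {'host': '10.0.0.5', 'port': 3306, 'chrome_username': 'web_user', 'chrome_password': 'secret'}
visible = {k: v for k, v in attrs.items() if 'password' not in k}
assert visible == {'host': '10.0.0.5', 'port': 3306, 'chrome_username': 'web_user'}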

View File

@@ -1,36 +1,14 @@
# coding: utf-8
#
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from common.serializers import AdaptedBulkListSerializer
from .. import models
class DBAttrsSerializer(serializers.Serializer):
host = serializers.CharField(max_length=128, label=_('Host'))
port = serializers.IntegerField(label=_('Port'))
database = serializers.CharField(
max_length=128, required=False, allow_blank=True, allow_null=True, label=_('Database')
)
class MySQLAttrsSerializer(DBAttrsSerializer):
port = serializers.IntegerField(default=3306, label=_('Port'))
class PostgreAttrsSerializer(DBAttrsSerializer):
port = serializers.IntegerField(default=5432, label=_('Port'))
class OracleAttrsSerializer(DBAttrsSerializer):
port = serializers.IntegerField(default=1521, label=_('Port'))
class MariaDBAttrsSerializer(MySQLAttrsSerializer):
pass
__all__ = [
'DatabaseAppSerializer',
]
class DatabaseAppSerializer(BulkOrgResourceModelSerializer):
@@ -46,6 +24,3 @@ class DatabaseAppSerializer(BulkOrgResourceModelSerializer):
'created_by', 'date_created', 'date_updated'
'get_type_display',
]
extra_kwargs = {
'get_type_display': {'label': _('Type for display')},
}

View File

@@ -1,27 +0,0 @@
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .. import models
class CloudAttrsSerializer(serializers.Serializer):
cluster = serializers.CharField(max_length=1024, label=_('Cluster'))
class K8sAttrsSerializer(CloudAttrsSerializer):
pass
class K8sAppSerializer(BulkOrgResourceModelSerializer):
type_display = serializers.CharField(source='get_type_display', read_only=True, label=_('Type for display'))
class Meta:
model = models.K8sApp
fields = [
'id', 'name', 'type', 'type_display', 'comment', 'created_by',
'date_created', 'date_updated', 'cluster'
]
read_only_fields = [
'id', 'created_by', 'date_created', 'date_updated',
]

View File

@@ -2,138 +2,21 @@
#
import copy
from django.utils.translation import ugettext_lazy as _
from django.core.exceptions import ObjectDoesNotExist
from rest_framework import serializers
from common.serializers import AdaptedBulkListSerializer
from common.fields.serializer import CustomMetaDictField
from common.utils import get_logger
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from assets.models import Asset
from .. import const
from ..models import RemoteApp, Category, Application
logger = get_logger(__file__)
from ..models import RemoteApp
class CharPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField):
def to_internal_value(self, data):
instance = super().to_internal_value(data)
return str(instance.id)
def to_representation(self, value):
# value is instance.id
if self.pk_field is not None:
return self.pk_field.to_representation(value)
return value
__all__ = [
'RemoteAppSerializer', 'RemoteAppConnectionInfoSerializer',
]
class RemoteAppAttrsSerializer(serializers.Serializer):
asset_info = serializers.SerializerMethodField()
asset = CharPrimaryKeyRelatedField(queryset=Asset.objects, required=False, label=_("Asset"))
path = serializers.CharField(max_length=128, label=_('Application path'))
@staticmethod
def get_asset_info(obj):
asset_info = {}
asset_id = obj.get('asset')
if not asset_id:
return asset_info
try:
asset = Asset.objects.get(id=asset_id)
asset_info.update({
'id': str(asset.id),
'hostname': asset.hostname
})
except ObjectDoesNotExist as e:
logger.error(e)
return asset_info
class ChromeAttrsSerializer(RemoteAppAttrsSerializer):
REMOTE_APP_PATH = 'C:\Program Files (x86)\Google\Chrome\Application\chrome.exe'
path = serializers.CharField(max_length=128, label=_('Application path'), default=REMOTE_APP_PATH)
chrome_target = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('Target URL'))
chrome_username = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('Username'))
chrome_password = serializers.CharField(max_length=128, allow_blank=True, required=False, write_only=True, label=_('Password'))
class MySQLWorkbenchAttrsSerializer(RemoteAppAttrsSerializer):
REMOTE_APP_PATH = 'C:\Program Files\MySQL\MySQL Workbench 8.0 CE\MySQLWorkbench.exe'
path = serializers.CharField(max_length=128, label=_('Application path'), default=REMOTE_APP_PATH)
mysql_workbench_ip = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('IP'))
mysql_workbench_port = serializers.IntegerField(required=False, label=_('Port'))
mysql_workbench_name = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('Database'))
mysql_workbench_username = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('Username'))
mysql_workbench_password = serializers.CharField(max_length=128, allow_blank=True, required=False, write_only=True, label=_('Password'))
class VMwareClientAttrsSerializer(RemoteAppAttrsSerializer):
REMOTE_APP_PATH = 'C:\Program Files (x86)\VMware\Infrastructure\Virtual Infrastructure Client\Launcher\VpxClient.exe'
path = serializers.CharField(max_length=128, label=_('Application path'), default=REMOTE_APP_PATH)
vmware_target = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('Target URL'))
vmware_username = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('Username'))
vmware_password = serializers.CharField(max_length=128, allow_blank=True, required=False, write_only=True, label=_('Password'))
class CustomRemoteAppAttrsSeralizers(RemoteAppAttrsSerializer):
custom_cmdline = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('Operating parameter'))
custom_target = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('Target url'))
custom_username = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('Username'))
custom_password = serializers.CharField(max_length=128, allow_blank=True, required=False, write_only=True, label=_('Password'))
class RemoteAppConnectionInfoSerializer(serializers.ModelSerializer):
parameter_remote_app = serializers.SerializerMethodField()
asset = serializers.SerializerMethodField()
class Meta:
model = Application
fields = [
'id', 'name', 'asset', 'parameter_remote_app',
]
read_only_fields = ['parameter_remote_app']
@staticmethod
def get_parameters(obj):
"""
Return the `parameters` part of the RemoteApp configuration info that Guacamole needs
"""
serializer_cls = Category.get_type_serializer_cls(obj.type)
fields = serializer_cls().get_fields()
fields.pop('asset', None)
fields_name = list(fields.keys())
attrs = obj.attrs
_parameters = list()
_parameters.append(obj.type)
for field_name in list(fields_name):
value = attrs.get(field_name, None)
if not value:
continue
if field_name == 'path':
value = '\"%s\"' % value
_parameters.append(str(value))
_parameters = ' '.join(_parameters)
return _parameters
def get_parameter_remote_app(self, obj):
parameters = self.get_parameters(obj)
parameter = {
'program': const.REMOTE_APP_BOOT_PROGRAM_NAME,
'working_directory': '',
'parameters': parameters,
}
return parameter
@staticmethod
def get_asset(obj):
return obj.attrs.get('asset')
# TODO: DELETE
class RemoteAppParamsDictField(CustomMetaDictField):
type_fields_map = const.REMOTE_APP_TYPE_FIELDS_MAP
default_type = const.REMOTE_APP_TYPE_CHROME
@@ -141,9 +24,8 @@ class RemoteAppParamsDictField(CustomMetaDictField):
convert_key_to_upper = False
# TODO: DELETE
class RemoteAppSerializer(BulkOrgResourceModelSerializer):
params = RemoteAppParamsDictField(label=_('Parameters'))
params = RemoteAppParamsDictField()
type_fields_map = const.REMOTE_APP_TYPE_FIELDS_MAP
class Meta:
@@ -157,10 +39,6 @@ class RemoteAppSerializer(BulkOrgResourceModelSerializer):
'created_by', 'date_created', 'asset_info',
'get_type_display'
]
extra_kwargs = {
'asset_info': {'label': _('Asset info')},
'get_type_display': {'label': _('Type for display')},
}
def process_params(self, instance, validated_data):
new_params = copy.deepcopy(validated_data.get('params', {}))
@@ -188,3 +66,21 @@ class RemoteAppSerializer(BulkOrgResourceModelSerializer):
return super().update(instance, validated_data)
class RemoteAppConnectionInfoSerializer(serializers.ModelSerializer):
parameter_remote_app = serializers.SerializerMethodField()
class Meta:
model = RemoteApp
fields = [
'id', 'name', 'asset', 'parameter_remote_app',
]
read_only_fields = ['parameter_remote_app']
@staticmethod
def get_parameter_remote_app(obj):
parameter = {
'program': const.REMOTE_APP_BOOT_PROGRAM_NAME,
'working_directory': '',
'parameters': obj.parameters,
}
return parameter
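For orientation, a self-contained sketch of the parameters string that get_parameters() above assembles for a hypothetical chrome-type application (attribute values are made up); the real method pulls the field order from the type's attrs serializer, skips empty values, and quotes the path:

app_type = 'chrome'
attrs = {
    'path': r'C:\Program Files (x86)\Google\Chrome\Application\chrome.exe',
    'chrome_target': 'https://example.com',
    'chrome_username': 'web_user',
    'chrome_password': 's3cret',
}
parts = [app_type]
for field_name in ['path', 'chrome_target', 'chrome_username', 'chrome_password']:
    value = attrs.get(field_name)
    if not value:
        continue
    if field_name == 'path':
        value = '"%s"' % value
    parts.append(str(value))
parameters = ' '.join(parts)
# -> 'chrome "C:\Program Files (x86)\...\chrome.exe" https://example.com web_user s3cret'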

View File

@@ -10,10 +10,8 @@ from .. import api
app_name = 'applications'
router = BulkRouter()
router.register(r'applications', api.ApplicationViewSet, 'application')
router.register(r'remote-apps', api.RemoteAppViewSet, 'remote-app')
router.register(r'database-apps', api.DatabaseAppViewSet, 'database-app')
router.register(r'k8s-apps', api.K8sAppViewSet, 'k8s-app')
urlpatterns = [
path('remote-apps/<uuid:pk>/connection-info/', api.RemoteAppConnectionInfoApi.as_view(), name='remote-app-connection-info'),

View File

@@ -1,4 +1,3 @@
from .mixin import *
from .admin_user import *
from .asset import *
from .label import *

View File

@@ -94,6 +94,7 @@ class AdminUserAssetsListView(generics.ListAPIView):
permission_classes = (IsOrgAdmin,)
serializer_class = serializers.AssetSimpleSerializer
filter_fields = ("hostname", "ip")
http_method_names = ['get']
search_fields = filter_fields
def get_object(self):

View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
from assets.api import FilterAssetByNodeMixin
from rest_framework.viewsets import ModelViewSet
from rest_framework.generics import RetrieveAPIView
from django.shortcuts import get_object_or_404
@@ -14,7 +14,7 @@ from .. import serializers
from ..tasks import (
update_asset_hardware_info_manual, test_asset_connectivity_manual
)
from ..filters import FilterAssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend
from ..filters import AssetByNodeFilterBackend, LabelFilterBackend
logger = get_logger(__file__)
@@ -25,7 +25,7 @@ __all__ = [
]
class AssetViewSet(FilterAssetByNodeMixin, OrgBulkModelViewSet):
class AssetViewSet(OrgBulkModelViewSet):
"""
API endpoint that allows Asset to be viewed or edited.
"""
@@ -41,7 +41,7 @@ class AssetViewSet(FilterAssetByNodeMixin, OrgBulkModelViewSet):
'display': serializers.AssetDisplaySerializer,
}
permission_classes = (IsOrgAdminOrAppUser,)
extra_filter_backends = [FilterAssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend]
extra_filter_backends = [AssetByNodeFilterBackend, LabelFilterBackend]
def set_assets_node(self, assets):
if not isinstance(assets, list):

View File

@@ -1,9 +1,7 @@
# ~*~ coding: utf-8 ~*~
from django.views.generic.detail import SingleObjectMixin
from django.utils.translation import ugettext as _
from rest_framework.views import APIView, Response
from rest_framework.serializers import ValidationError
from django.views.generic.detail import SingleObjectMixin
from common.utils import get_logger
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser
@@ -44,10 +42,6 @@ class GatewayTestConnectionApi(SingleObjectMixin, APIView):
def post(self, request, *args, **kwargs):
self.object = self.get_object(Gateway.objects.all())
local_port = self.request.data.get('port') or self.object.port
try:
local_port = int(local_port)
except ValueError:
raise ValidationError({'port': _('Number required')})
ok, e = self.object.test_connective(local_port=local_port)
if ok:
return Response("ok")

View File

@@ -1,89 +0,0 @@
from typing import List
from assets.models import Node, Asset
from assets.pagination import AssetLimitOffsetPagination
from common.utils import lazyproperty, dict_get_any, is_uuid, get_object_or_none
from assets.utils import get_node, is_query_node_all_assets
class SerializeToTreeNodeMixin:
permission_classes = ()
def serialize_nodes(self, nodes: List[Node], with_asset_amount=False):
if with_asset_amount:
def _name(node: Node):
return '{} ({})'.format(node.value, node.assets_amount)
else:
def _name(node: Node):
return node.value
data = [
{
'id': node.key,
'name': _name(node),
'title': _name(node),
'pId': node.parent_key,
'isParent': True,
'open': node.is_org_root(),
'meta': {
'node': {
"id": node.id,
"key": node.key,
"value": node.value,
},
'type': 'node'
}
}
for node in nodes
]
return data
def get_platform(self, asset: Asset):
default = 'file'
icon = {'windows', 'linux'}
platform = asset.platform_base.lower()
if platform in icon:
return platform
return default
def serialize_assets(self, assets, node_key=None):
if node_key is None:
get_pid = lambda asset: getattr(asset, 'parent_key', '')
else:
get_pid = lambda asset: node_key
data = [
{
'id': str(asset.id),
'name': asset.hostname,
'title': asset.ip,
'pId': get_pid(asset),
'isParent': False,
'open': False,
'iconSkin': self.get_platform(asset),
'chkDisabled': not asset.is_active,
'meta': {
'type': 'asset',
'asset': {
'id': asset.id,
'hostname': asset.hostname,
'ip': asset.ip,
'protocols': asset.protocols_as_list,
'platform': asset.platform_base,
},
}
}
for asset in assets
]
return data
class FilterAssetByNodeMixin:
pagination_class = AssetLimitOffsetPagination
@lazyproperty
def is_query_node_all_assets(self):
return is_query_node_all_assets(self.request)
@lazyproperty
def node(self):
return get_node(self.request)

View File

@@ -1,24 +1,16 @@
# ~*~ coding: utf-8 ~*~
from functools import partial
from collections import namedtuple, defaultdict
from collections import namedtuple
from rest_framework import status
from rest_framework.serializers import ValidationError
from rest_framework.response import Response
from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404, Http404
from django.utils.decorators import method_decorator
from django.db.models.signals import m2m_changed
from common.exceptions import SomeoneIsDoingThis
from common.const.signals import PRE_REMOVE, POST_REMOVE
from assets.models import Asset
from common.utils import get_logger, get_object_or_none
from common.tree import TreeNodeSerializer
from common.const.distributed_lock_key import UPDATE_NODE_TREE_LOCK_KEY
from orgs.mixins.api import OrgModelViewSet
from orgs.mixins import generics
from orgs.lock import org_level_transaction_lock
from ..hands import IsOrgAdmin
from ..models import Node
from ..tasks import (
@@ -26,13 +18,12 @@ from ..tasks import (
test_node_assets_connectivity_manual,
)
from .. import serializers
from .mixin import SerializeToTreeNodeMixin
logger = get_logger(__file__)
__all__ = [
'NodeViewSet', 'NodeChildrenApi', 'NodeAssetsApi',
'NodeAddAssetsApi', 'NodeRemoveAssetsApi', 'MoveAssetsToNodeApi',
'NodeAddAssetsApi', 'NodeRemoveAssetsApi', 'NodeReplaceAssetsApi',
'NodeAddChildrenApi', 'NodeListAsTreeApi',
'NodeChildrenAsTreeApi',
'NodeTaskCreateApi',
@@ -145,7 +136,7 @@ class NodeChildrenApi(generics.ListCreateAPIView):
return queryset
class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
class NodeChildrenAsTreeApi(NodeChildrenApi):
"""
Return the node's children as a tree,
[
@@ -159,23 +150,31 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
"""
model = Node
serializer_class = TreeNodeSerializer
http_method_names = ['get']
def list(self, request, *args, **kwargs):
nodes = self.get_queryset().order_by('value')
nodes = self.serialize_nodes(nodes, with_asset_amount=True)
assets = self.get_assets()
data = [*nodes, *assets]
return Response(data=data)
def get_queryset(self):
queryset = super().get_queryset()
queryset = [node.as_tree_node() for node in queryset]
queryset = self.add_assets_if_need(queryset)
queryset = sorted(queryset)
return queryset
def get_assets(self):
def add_assets_if_need(self, queryset):
include_assets = self.request.query_params.get('assets', '0') == '1'
if not include_assets:
return []
return queryset
assets = self.instance.get_assets().only(
"id", "hostname", "ip", "os",
"org_id", "protocols",
)
return self.serialize_assets(assets, self.instance.key)
for asset in assets:
queryset.append(asset.as_tree_node(self.instance))
return queryset
def check_need_refresh_nodes(self):
if self.request.query_params.get('refresh', '0') == '1':
Node.refresh_nodes()
class NodeAssetsApi(generics.ListAPIView):
@@ -209,8 +208,6 @@ class NodeAddChildrenApi(generics.UpdateAPIView):
return Response("OK")
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeAddAssetsApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
@@ -223,8 +220,6 @@ class NodeAddAssetsApi(generics.UpdateAPIView):
instance.assets.add(*tuple(assets))
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeRemoveAssetsApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
@@ -233,17 +228,15 @@ class NodeRemoveAssetsApi(generics.UpdateAPIView):
def perform_update(self, serializer):
assets = serializer.validated_data.get('assets')
node = self.get_object()
node.assets.remove(*assets)
# Add orphaned assets to the org root node
orphan_assets = Asset.objects.filter(id__in=[a.id for a in assets], nodes__isnull=True).distinct()
Node.org_root().assets.add(*orphan_assets)
instance = self.get_object()
if instance != Node.org_root():
instance.assets.remove(*tuple(assets))
else:
assets = [asset for asset in assets if asset.nodes.count() > 1]
instance.assets.remove(*tuple(assets))
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class MoveAssetsToNodeApi(generics.UpdateAPIView):
class NodeReplaceAssetsApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
permission_classes = (IsOrgAdmin,)
@@ -251,39 +244,9 @@ class MoveAssetsToNodeApi(generics.UpdateAPIView):
def perform_update(self, serializer):
assets = serializer.validated_data.get('assets')
node = self.get_object()
self.remove_old_nodes(assets)
node.assets.add(*assets)
def remove_old_nodes(self, assets):
m2m_model = Asset.nodes.through
# Query the asset-node relation table for all relations of the assets being moved
relates = m2m_model.objects.filter(asset__in=assets).values_list('asset_id', 'node_id')
if relates:
# Group the relations by asset, to send the `reverse=False` signals
asset_nodes_mapper = defaultdict(set)
for asset_id, node_id in relates:
asset_nodes_mapper[asset_id].add(node_id)
# Build an asset id -> Asset mapper
asset_mapper = {asset.id: asset for asset in assets}
# Build the sender functions for the relation-removal signals
senders = []
for asset_id, node_id_set in asset_nodes_mapper.items():
senders.append(partial(m2m_changed.send, sender=m2m_model, instance=asset_mapper[asset_id],
reverse=False, model=Node, pk_set=node_id_set))
# Send the pre signals
[sender(action=PRE_REMOVE) for sender in senders]
num = len(relates)
asset_ids, node_ids = zip(*relates)
# Delete the old relations
rows, _i = m2m_model.objects.filter(asset_id__in=asset_ids, node_id__in=node_ids).delete()
if rows != num:
raise SomeoneIsDoingThis
# Send the post signals
[sender(action=POST_REMOVE) for sender in senders]
instance = self.get_object()
for asset in assets:
asset.nodes.set([instance])
class NodeTaskCreateApi(generics.CreateAPIView):
@@ -304,6 +267,7 @@ class NodeTaskCreateApi(generics.CreateAPIView):
@staticmethod
def refresh_nodes_cache():
Node.refresh_nodes()
Task = namedtuple('Task', ['id'])
task = Task(id="0")
return task

View File

@@ -19,7 +19,7 @@ from ..tasks import (
logger = get_logger(__file__)
__all__ = [
'SystemUserViewSet', 'SystemUserAuthInfoApi', 'SystemUserAssetAuthInfoApi',
'SystemUserCommandFilterRuleListApi', 'SystemUserTaskApi', 'SystemUserAssetsListView',
'SystemUserCommandFilterRuleListApi', 'SystemUserTaskApi',
]
@@ -125,18 +125,3 @@ class SystemUserCommandFilterRuleListApi(generics.ListAPIView):
pk = self.kwargs.get('pk', None)
system_user = get_object_or_404(SystemUser, pk=pk)
return system_user.cmd_filter_rules
class SystemUserAssetsListView(generics.ListAPIView):
permission_classes = (IsOrgAdmin,)
serializer_class = serializers.AssetSimpleSerializer
filter_fields = ("hostname", "ip")
search_fields = filter_fields
def get_object(self):
pk = self.kwargs.get('pk')
return get_object_or_404(SystemUser, pk=pk)
def get_queryset(self):
system_user = self.get_object()
return system_user.get_all_assets()

View File

@@ -95,7 +95,7 @@ class SystemUserNodeRelationViewSet(BaseRelationViewSet):
'id', 'node', 'systemuser',
]
search_fields = [
"node__value", "systemuser__name", "systemuser__username"
"node__value", "systemuser__name", "systemuser_username"
]
def get_objects_attr(self):

View File

@@ -1,6 +0,0 @@
from rest_framework import status
from common.exceptions import JMSException
class NodeIsBeingUpdatedByOthers(JMSException):
status_code = status.HTTP_409_CONFLICT

View File

@@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
#
from rest_framework.compat import coreapi, coreschema
import coreapi
from rest_framework import filters
from django.db.models import Q
from .models import Label
from assets.utils import is_query_node_all_assets, get_node
from common.utils import dict_get_any, is_uuid, get_object_or_none
from .models import Node, Label
class AssetByNodeFilterBackend(filters.BaseFilterBackend):
@@ -21,54 +21,47 @@ class AssetByNodeFilterBackend(filters.BaseFilterBackend):
for field in self.fields
]
def filter_node_related_all(self, queryset, node):
return queryset.filter(
Q(nodes__key__istartswith=f'{node.key}:') |
Q(nodes__key=node.key)
).distinct()
@staticmethod
def is_query_all(request):
query_all_arg = request.query_params.get('all')
show_current_asset_arg = request.query_params.get('show_current_asset')
def filter_node_related_direct(self, queryset, node):
return queryset.filter(nodes__key=node.key).distinct()
query_all = query_all_arg == '1'
if show_current_asset_arg is not None:
query_all = show_current_asset_arg != '1'
return query_all
@staticmethod
def get_query_node(request):
node_id = dict_get_any(request.query_params, ['node', 'node_id'])
if not node_id:
return None, False
if is_uuid(node_id):
node = get_object_or_none(Node, id=node_id)
else:
node = get_object_or_none(Node, key=node_id)
return node, True
@staticmethod
def perform_query(pattern, queryset):
return queryset.filter(nodes__key__regex=pattern).distinct()
def filter_queryset(self, request, queryset, view):
node = get_node(request)
if node is None:
node, has_query_arg = self.get_query_node(request)
if not has_query_arg:
return queryset
query_all = is_query_node_all_assets(request)
if query_all:
return self.filter_node_related_all(queryset, node)
else:
return self.filter_node_related_direct(queryset, node)
class FilterAssetByNodeFilterBackend(filters.BaseFilterBackend):
"""
Must be used together with `assets.api.mixin.FilterAssetByNodeMixin`
"""
fields = ['node', 'all']
def get_schema_fields(self, view):
return [
coreapi.Field(
name=field, location='query', required=False,
type='string', example='', description='', schema=None,
)
for field in self.fields
]
def filter_queryset(self, request, queryset, view):
node = view.node
if node is None:
return queryset
query_all = view.is_query_node_all_assets
query_all = self.is_query_all(request)
if query_all:
return queryset.filter(
Q(nodes__key__istartswith=f'{node.key}:') |
Q(nodes__key=node.key)
).distinct()
pattern = node.get_all_children_pattern(with_self=True)
else:
return queryset.filter(nodes__key=node.key).distinct()
# pattern = node.get_children_key_pattern(with_self=True)
# Only show assets directly under the current node
pattern = r"^{}$".format(node.key)
return self.perform_query(pattern, queryset)
class LabelFilterBackend(filters.BaseFilterBackend):
@@ -120,32 +113,7 @@ class LabelFilterBackend(filters.BaseFilterBackend):
class AssetRelatedByNodeFilterBackend(AssetByNodeFilterBackend):
def filter_node_related_all(self, queryset, node):
return queryset.filter(
Q(asset__nodes__key__istartswith=f'{node.key}:') |
Q(asset__nodes__key=node.key)
).distinct()
@staticmethod
def perform_query(pattern, queryset):
return queryset.filter(asset__nodes__key__regex=pattern).distinct()
def filter_node_related_direct(self, queryset, node):
return queryset.filter(asset__nodes__key=node.key).distinct()
class IpInFilterBackend(filters.BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
ips = request.query_params.get('ips')
if not ips:
return queryset
ip_list = [i.strip() for i in ips.split(',')]
queryset = queryset.filter(ip__in=ip_list)
return queryset
def get_schema_fields(self, view):
return [
coreapi.Field(
name='ips', location='query', required=False, type='string',
schema=coreschema.String(
title='ips',
description='ip in filter'
)
)
]
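In plain terms, the query parameters handled in this filters module combine as follows (parameter names are taken from the code above; the URL routes themselves are defined elsewhere and not shown):

# ?node_id=<uuid>               -> filter assets directly under that node
# ?node=<key such as 1:2:3>     -> non-UUID values are looked up by node key
# ...&all=1                     -> also include assets under all descendant nodes
# ...&show_current_asset=1      -> overrides `all` and keeps only directly attached assets
# ?ips=10.0.0.1,10.0.0.2        -> IpInFilterBackend restricts the queryset to those exact IPs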

View File

@@ -1,18 +0,0 @@
# Generated by Django 2.2.10 on 2020-07-11 09:40
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0049_systemuser_sftp_root'),
]
operations = [
migrations.AlterField(
model_name='asset',
name='created_by',
field=models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by'),
),
]

View File

@@ -1,22 +0,0 @@
# Generated by Django 2.2.10 on 2020-07-13 03:43
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0050_auto_20200711_1740'),
]
operations = [
migrations.AlterField(
model_name='domain',
name='name',
field=models.CharField(max_length=128, verbose_name='Name'),
),
migrations.AlterUniqueTogether(
name='domain',
unique_together={('org_id', 'name')},
),
]

View File

@@ -1,22 +0,0 @@
# Generated by Django 2.2.10 on 2020-07-15 07:35
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0051_auto_20200713_1143'),
]
operations = [
migrations.AlterField(
model_name='commandfilter',
name='name',
field=models.CharField(max_length=64, verbose_name='Name'),
),
migrations.AlterUniqueTogether(
name='commandfilter',
unique_together={('org_id', 'name')},
),
]

View File

@@ -1,34 +0,0 @@
# Generated by Django 2.2.10 on 2020-07-23 04:32
import django.core.validators
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0052_auto_20200715_1535'),
]
operations = [
migrations.AlterField(
model_name='adminuser',
name='username',
field=models.CharField(blank=True, db_index=True, max_length=128, validators=[django.core.validators.RegexValidator('^[0-9a-zA-Z_@\\-\\.]*$', 'Special char not allowed')], verbose_name='Username'),
),
migrations.AlterField(
model_name='authbook',
name='username',
field=models.CharField(blank=True, db_index=True, max_length=128, validators=[django.core.validators.RegexValidator('^[0-9a-zA-Z_@\\-\\.]*$', 'Special char not allowed')], verbose_name='Username'),
),
migrations.AlterField(
model_name='gateway',
name='username',
field=models.CharField(blank=True, db_index=True, max_length=128, validators=[django.core.validators.RegexValidator('^[0-9a-zA-Z_@\\-\\.]*$', 'Special char not allowed')], verbose_name='Username'),
),
migrations.AlterField(
model_name='systemuser',
name='username',
field=models.CharField(blank=True, db_index=True, max_length=128, validators=[django.core.validators.RegexValidator('^[0-9a-zA-Z_@\\-\\.]*$', 'Special char not allowed')], verbose_name='Username'),
),
]

View File

@@ -1,23 +0,0 @@
# Generated by Django 2.2.13 on 2020-08-07 02:32
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0053_auto_20200723_1232'),
]
operations = [
migrations.AddField(
model_name='systemuser',
name='token',
field=models.TextField(default='', verbose_name='Token'),
),
migrations.AlterField(
model_name='systemuser',
name='protocol',
field=models.CharField(choices=[('ssh', 'ssh'), ('rdp', 'rdp'), ('telnet', 'telnet'), ('vnc', 'vnc'), ('mysql', 'mysql'), ('k8s', 'k8s')], default='ssh', max_length=16, verbose_name='Protocol'),
),
]

View File

@@ -1,23 +0,0 @@
# Generated by Django 2.2.13 on 2020-08-11 10:45
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0054_auto_20200807_1032'),
]
operations = [
migrations.AddField(
model_name='systemuser',
name='home',
field=models.CharField(blank=True, default='', max_length=4096, verbose_name='Home'),
),
migrations.AddField(
model_name='systemuser',
name='system_groups',
field=models.CharField(blank=True, default='', max_length=4096, verbose_name='System groups'),
),
]

View File

@@ -1,23 +0,0 @@
# Generated by Django 2.2.13 on 2020-09-04 09:51
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0055_auto_20200811_1845'),
]
operations = [
migrations.AddField(
model_name='node',
name='assets_amount',
field=models.IntegerField(default=0),
),
migrations.AddField(
model_name='node',
name='parent_key',
field=models.CharField(db_index=True, default='', max_length=64, verbose_name='Parent key'),
),
]

View File

@@ -1,38 +0,0 @@
# Generated by Django 2.2.13 on 2020-08-21 08:20
from django.db import migrations
from django.db.models import Q
def fill_node_value(apps, schema_editor):
Node = apps.get_model('assets', 'Node')
Asset = apps.get_model('assets', 'Asset')
node_queryset = Node.objects.all()
node_amount = node_queryset.count()
width = len(str(node_amount))
print('\n')
for i, node in enumerate(node_queryset):
print(f'\t{i+1:0>{width}}/{node_amount} compute node[{node.key}]`s assets_amount ...')
assets_amount = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
key = node.key
try:
parent_key = key[:key.rindex(':')]
except ValueError:
parent_key = ''
node.assets_amount = assets_amount
node.parent_key = parent_key
node.save()
print(' ' + '.'*65, end='')
class Migration(migrations.Migration):
dependencies = [
('assets', '0056_auto_20200904_1751'),
]
operations = [
migrations.RunPython(fill_node_value)
]
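A quick worked example of the parent_key derivation used in fill_node_value above (keys are invented):

key = '1:5:2'
parent_key = key[:key.rindex(':')]
assert parent_key == '1:5'
# A first-level key such as '1' contains no ':', so rindex() raises ValueError
# and the migration falls back to parent_key = ''.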

View File

@@ -1,27 +0,0 @@
# Generated by Django 2.2.13 on 2020-10-23 03:15
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0057_fill_node_value_assets_amount_and_parent_key'),
]
operations = [
migrations.AlterModelOptions(
name='asset',
options={'ordering': ['-date_created'], 'verbose_name': 'Asset'},
),
migrations.AlterField(
model_name='asset',
name='comment',
field=models.TextField(blank=True, default='', verbose_name='Comment'),
),
migrations.AlterField(
model_name='commandfilterrule',
name='content',
field=models.TextField(help_text='One line one command', verbose_name='Content'),
),
]

View File

@@ -1,28 +0,0 @@
# Generated by Django 2.2.13 on 2020-10-27 11:05
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0058_auto_20201023_1115'),
]
operations = [
migrations.AlterField(
model_name='systemuser',
name='protocol',
field=models.CharField(choices=[('ssh', 'ssh'), ('rdp', 'rdp'), ('telnet', 'telnet'), ('vnc', 'vnc'), ('mysql', 'mysql'), ('oracle', 'oracle'), ('mariadb', 'mariadb'), ('postgresql', 'postgresql'), ('k8s', 'k8s')], default='ssh', max_length=16, verbose_name='Protocol'),
),
migrations.AddField(
model_name='systemuser',
name='ad_domain',
field=models.CharField(default='', max_length=256),
),
migrations.AlterField(
model_name='gateway',
name='ip',
field=models.CharField(db_index=True, max_length=128, verbose_name='IP'),
),
]

View File

@@ -1,58 +0,0 @@
# Generated by Django 2.2.13 on 2020-10-26 11:31
from django.db import migrations, models
def get_node_ancestor_keys(key, with_self=False):
parent_keys = []
key_list = key.split(":")
if not with_self:
key_list.pop()
for i in range(len(key_list)):
parent_keys.append(":".join(key_list))
key_list.pop()
return parent_keys
def migrate_nodes_value_with_slash(apps, schema_editor):
model = apps.get_model("assets", "Node")
db_alias = schema_editor.connection.alias
nodes = model.objects.using(db_alias).filter(value__contains='/')
print('')
print("- Start migrate node value if has /")
for i, node in enumerate(list(nodes)):
new_value = node.value.replace('/', '|')
print("{} start migrate node value: {} => {}".format(i, node.value, new_value))
node.value = new_value
node.save()
def migrate_nodes_full_value(apps, schema_editor):
model = apps.get_model("assets", "Node")
db_alias = schema_editor.connection.alias
nodes = model.objects.using(db_alias).all()
print("- Start migrate node full value")
for i, node in enumerate(list(nodes)):
print("{} start migrate {} node full value".format(i, node.value))
ancestor_keys = get_node_ancestor_keys(node.key, True)
values = model.objects.filter(key__in=ancestor_keys).values_list('key', 'value')
values = [v for k, v in sorted(values, key=lambda x: len(x[0]))]
node.full_value = '/' + '/'.join(values)
node.save()
class Migration(migrations.Migration):
dependencies = [
('assets', '0059_auto_20201027_1905'),
]
operations = [
migrations.AddField(
model_name='node',
name='full_value',
field=models.CharField(default='', max_length=4096, verbose_name='Full value'),
),
migrations.RunPython(migrate_nodes_value_with_slash),
migrations.RunPython(migrate_nodes_full_value)
]
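
For illustration (not part of the diff above): how `get_node_ancestor_keys` above expands a key, which `migrate_nodes_full_value` then joins into `full_value`; the node values shown are hypothetical.

# get_node_ancestor_keys('1:2:3')                  -> ['1:2', '1']
# get_node_ancestor_keys('1:2:3', with_self=True)  -> ['1:2:3', '1:2', '1']
#
# The matching values, sorted by key length so ancestors come first, are then joined:
#   full_value = '/' + '/'.join(['Default', 'web', 'prod'])   # -> '/Default/web/prod'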

View File

@@ -1,17 +0,0 @@
# Generated by Django 2.2.13 on 2020-11-16 09:57
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('assets', '0060_node_full_value'),
]
operations = [
migrations.AlterModelOptions(
name='node',
options={'ordering': ['value'], 'verbose_name': 'Node'},
),
]

View File

@@ -1,17 +0,0 @@
# Generated by Django 2.2.13 on 2020-11-17 11:38
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('assets', '0061_auto_20201116_1757'),
]
operations = [
migrations.AlterModelOptions(
name='asset',
options={'ordering': ['hostname', 'ip'], 'verbose_name': 'Asset'},
),
]

View File

@@ -47,10 +47,6 @@ class AssetManager(OrgManager):
)
class AssetOrgManager(OrgManager):
pass
class AssetQuerySet(models.QuerySet):
def active(self):
return self.filter(is_active=True)
@@ -225,12 +221,11 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
hostname_raw = models.CharField(max_length=128, blank=True, null=True, verbose_name=_('Hostname raw'))
labels = models.ManyToManyField('assets.Label', blank=True, related_name='assets', verbose_name=_("Labels"))
created_by = models.CharField(max_length=128, null=True, blank=True, verbose_name=_('Created by'))
created_by = models.CharField(max_length=32, null=True, blank=True, verbose_name=_('Created by'))
date_created = models.DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created'))
comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
comment = models.TextField(max_length=128, default='', blank=True, verbose_name=_('Comment'))
objects = AssetManager.from_queryset(AssetQuerySet)()
org_objects = AssetOrgManager.from_queryset(AssetQuerySet)()
_connectivity = None
def __str__(self):
@@ -249,6 +244,10 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
def platform_base(self):
return self.platform.base
@lazyproperty
def admin_user_display(self):
return self.admin_user.name
@lazyproperty
def admin_user_username(self):
"""求可连接性时直接用用户名去取避免再查一次admin user
@@ -313,12 +312,6 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
}
return info
def nodes_display(self):
names = []
for n in self.nodes.all():
names.append(n.full_value)
return names
def as_node(self):
from .node import Node
fake_node = Node()
@@ -361,4 +354,36 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
class Meta:
unique_together = [('org_id', 'hostname')]
verbose_name = _("Asset")
ordering = ["hostname", "ip"]
@classmethod
def generate_fake(cls, count=100):
from .user import AdminUser, SystemUser
import random
from random import seed, choice
from django.db import IntegrityError
from .node import Node
from orgs.utils import get_current_org
from orgs.models import Organization
org = get_current_org()
if not org or not org.is_real():
Organization.default().change_to()
nodes = list(Node.objects.all())
seed()
for i in range(count):
ip = [str(i) for i in random.sample(range(255), 4)]
asset = cls(ip='.'.join(ip),
hostname='.'.join(ip),
admin_user=choice(AdminUser.objects.all()),
created_by='Fake')
try:
asset.save()
asset.protocols = 'ssh/22'
if nodes and len(nodes) > 3:
_nodes = random.sample(nodes, 3)
else:
_nodes = [Node.default_node()]
asset.nodes.set(_nodes)
logger.debug('Generate fake asset : %s' % asset.ip)
except IntegrityError:
print('Error continue')
continue

View File

@@ -3,6 +3,7 @@
from django.db import models, transaction
from django.db.models import Max
from django.core.cache import cache
from django.utils.translation import ugettext_lazy as _
from orgs.mixins.models import OrgManager
@@ -58,17 +59,19 @@ class AuthBook(BaseUser):
"""
username = kwargs['username']
asset = kwargs['asset']
with transaction.atomic():
# Use select_for_update to prevent concurrent creation of entries with the same username and asset
instances = cls.objects.select_for_update().filter(username=username, asset=asset)
instances.filter(is_latest=True).update(is_latest=False)
max_version = cls.get_max_version(username, asset)
kwargs.update({
'version': max_version + 1,
'is_latest': True
})
obj = cls.objects.create(**kwargs)
return obj
key_lock = 'KEY_LOCK_CREATE_AUTH_BOOK_{}_{}'.format(username, asset.id)
with cache.lock(key_lock):
with transaction.atomic():
cls.objects.filter(
username=username, asset=asset, is_latest=True
).update(is_latest=False)
max_version = cls.get_max_version(username, asset)
kwargs.update({
'version': max_version + 1,
'is_latest': True
})
obj = cls.objects.create(**kwargs)
return obj
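
For context (not part of the diff above): the hunk trades a `select_for_update()` row lock for a named cache lock. A minimal sketch of that pattern, assuming a cache backend such as django-redis whose client exposes `lock()`; the helper name is hypothetical.

from django.core.cache import cache
from django.db import transaction

def create_latest_version(model, username, asset, **kwargs):
    # One lock per (username, asset) pair: concurrent creations of the same
    # credential are serialized, unrelated pairs are not blocked.
    key_lock = 'KEY_LOCK_CREATE_AUTH_BOOK_{}_{}'.format(username, asset.id)
    with cache.lock(key_lock):
        with transaction.atomic():
            model.objects.filter(username=username, asset=asset, is_latest=True)\
                 .update(is_latest=False)
            # version / is_latest bookkeeping as in the hunk above
            return model.objects.create(username=username, asset=asset, is_latest=True, **kwargs)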
@property
def connectivity(self):

View File

@@ -158,11 +158,9 @@ class AuthMixin:
if update_fields:
self.save(update_fields=update_fields)
def has_special_auth(self, asset=None, username=None):
def has_special_auth(self, asset=None):
from .authbook import AuthBook
if username is None:
username = self.username
queryset = AuthBook.objects.filter(username=username)
queryset = AuthBook.objects.filter(username=self.username)
if asset:
queryset = queryset.filter(asset=asset)
return queryset.exists()
@@ -232,7 +230,7 @@ class AuthMixin:
class BaseUser(OrgModelMixin, AuthMixin, ConnectivityMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=128, verbose_name=_('Name'))
username = models.CharField(max_length=128, blank=True, verbose_name=_('Username'), validators=[alphanumeric], db_index=True)
username = models.CharField(max_length=32, blank=True, verbose_name=_('Username'), validators=[alphanumeric], db_index=True)
password = fields.EncryptCharField(max_length=256, blank=True, null=True, verbose_name=_('Password'))
private_key = fields.EncryptTextField(blank=True, null=True, verbose_name=_('SSH private key'))
public_key = fields.EncryptTextField(blank=True, null=True, verbose_name=_('SSH public key'))

View File

@@ -38,3 +38,27 @@ class Cluster(models.Model):
class Meta:
ordering = ['name']
verbose_name = _("Cluster")
@classmethod
def generate_fake(cls, count=5):
from random import seed, choice
import forgery_py
from django.db import IntegrityError
seed()
for i in range(count):
cluster = cls(name=forgery_py.name.full_name(),
bandwidth='200M',
contact=forgery_py.name.full_name(),
phone=forgery_py.address.phone(),
address=forgery_py.address.city() + forgery_py.address.street_address(),
# operator=choice(['Beijing Unicom', 'Beijing Telecom', 'BGP full netcom']),
operator=choice([_('Beijing unicom'), _('Beijing telecom'), _('BGP full netcom')]),
comment=forgery_py.lorem_ipsum.sentence(),
created_by='Fake')
try:
cluster.save()
logger.debug('Generate fake asset group: %s' % cluster.name)
except IntegrityError:
print('Error continue')
continue

View File

@@ -18,7 +18,7 @@ __all__ = [
class CommandFilter(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=64, verbose_name=_("Name"))
name = models.CharField(max_length=64, unique=True, verbose_name=_("Name"))
is_active = models.BooleanField(default=True, verbose_name=_('Is active'))
comment = models.TextField(blank=True, default='', verbose_name=_("Comment"))
date_created = models.DateTimeField(auto_now_add=True)
@@ -29,7 +29,6 @@ class CommandFilter(OrgModelMixin):
return self.name
class Meta:
unique_together = [('org_id', 'name')]
verbose_name = _("Command filter")
@@ -52,7 +51,7 @@ class CommandFilterRule(OrgModelMixin):
type = models.CharField(max_length=16, default=TYPE_COMMAND, choices=TYPE_CHOICES, verbose_name=_("Type"))
priority = models.IntegerField(default=50, verbose_name=_("Priority"), help_text=_("1-100, the higher will be match first"),
validators=[MinValueValidator(1), MaxValueValidator(100)])
content = models.TextField(verbose_name=_("Content"), help_text=_("One line one command"))
content = models.TextField(max_length=1024, verbose_name=_("Content"), help_text=_("One line one command"))
action = models.IntegerField(default=ACTION_DENY, choices=ACTION_CHOICES, verbose_name=_("Action"))
comment = models.CharField(max_length=64, blank=True, default='', verbose_name=_("Comment"))
date_created = models.DateTimeField(auto_now_add=True)

View File

@@ -9,7 +9,6 @@ import paramiko
from django.db import models
from django.utils.translation import ugettext_lazy as _
from common.utils.strings import no_special_chars
from orgs.mixins.models import OrgModelMixin
from .base import BaseUser
@@ -18,14 +17,13 @@ __all__ = ['Domain', 'Gateway']
class Domain(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=128, verbose_name=_('Name'))
name = models.CharField(max_length=128, unique=True, verbose_name=_('Name'))
comment = models.TextField(blank=True, verbose_name=_('Comment'))
date_created = models.DateTimeField(auto_now_add=True, null=True,
verbose_name=_('Date created'))
class Meta:
verbose_name = _("Domain")
unique_together = [('org_id', 'name')]
def __str__(self):
return self.name
@@ -48,7 +46,7 @@ class Gateway(BaseUser):
(PROTOCOL_SSH, 'ssh'),
(PROTOCOL_RDP, 'rdp'),
)
ip = models.CharField(max_length=128, verbose_name=_('IP'), db_index=True)
ip = models.GenericIPAddressField(max_length=32, verbose_name=_('IP'), db_index=True)
port = models.IntegerField(default=22, verbose_name=_('Port'))
protocol = models.CharField(choices=PROTOCOL_CHOICES, max_length=16, default=PROTOCOL_SSH, verbose_name=_("Protocol"))
domain = models.ForeignKey(Domain, on_delete=models.CASCADE, verbose_name=_("Domain"))
@@ -65,8 +63,8 @@ class Gateway(BaseUser):
def test_connective(self, local_port=None):
if local_port is None:
local_port = self.port
if self.password and not no_special_chars(self.password):
return False, _("Password should not contains special characters")
if self.password and not re.match(r'\w+$', self.password):
return False, _("Password should not contain special characters")
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

View File

@@ -18,15 +18,3 @@ class FavoriteAsset(CommonModelMixin):
@classmethod
def get_user_favorite_assets_id(cls, user):
return cls.objects.filter(user=user).values_list('asset', flat=True)
@classmethod
def get_user_favorite_assets(cls, user, asset_perms_id=None):
from assets.models import Asset
from perms.utils.asset.user_permission import get_user_granted_all_assets
asset_ids = get_user_granted_all_assets(
user,
via_mapping_node=False,
asset_perms_id=asset_perms_id
).values_list('id', flat=True)
query_name = cls.asset.field.related_query_name()
return Asset.org_objects.filter(**{f'{query_name}__user_id': user.id}, id__in=asset_ids).distinct()

View File

@@ -33,3 +33,21 @@ class AssetGroup(models.Model):
def initial(cls):
asset_group = cls(name=_('Default'), comment=_('Default asset group'))
asset_group.save()
@classmethod
def generate_fake(cls, count=100):
from random import seed
import forgery_py
from django.db import IntegrityError
seed()
for i in range(count):
group = cls(name=forgery_py.name.full_name(),
comment=forgery_py.lorem_ipsum.sentence(),
created_by='Fake')
try:
group.save()
logger.debug('Generate fake asset group: %s' % group.name)
except IntegrityError:
print('Error continue')
continue

View File

@@ -2,35 +2,132 @@
#
import uuid
import re
import time
from django.db import models, transaction
from django.db.models import Q
from django.db.utils import IntegrityError
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ugettext
from django.db.transaction import atomic
from django.core.cache import cache
from common.utils import get_logger
from common.utils.common import lazyproperty
from common.utils import get_logger, lazyproperty
from orgs.mixins.models import OrgModelMixin, OrgManager
from orgs.utils import get_current_org, tmp_to_org
from orgs.utils import get_current_org, tmp_to_org, current_org
from orgs.models import Organization
__all__ = ['Node', 'FamilyMixin', 'compute_parent_key']
__all__ = ['Node']
logger = get_logger(__name__)
def compute_parent_key(key):
try:
return key[:key.rindex(':')]
except ValueError:
return ''
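
For illustration (not part of the diff above), the behaviour of `compute_parent_key`:

# compute_parent_key('1:5:2')  -> '1:5'
# compute_parent_key('1')      -> ''    (an org-root key has no parent)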
class NodeQuerySet(models.QuerySet):
def delete(self):
raise NotImplementedError
raise PermissionError("Bulk delete node deny")
class TreeCache:
updated_time_cache_key = 'NODE_TREE_UPDATED_AT_{}'
cache_time = 3600
assets_updated_time_cache_key = 'NODE_TREE_ASSETS_UPDATED_AT_{}'
def __init__(self, tree, org_id):
now = time.time()
self.created_time = now
self.assets_created_time = now
self.tree = tree
self.org_id = org_id
def _has_changed(self, tp="tree"):
if tp == "assets":
key = self.assets_updated_time_cache_key.format(self.org_id)
else:
key = self.updated_time_cache_key.format(self.org_id)
updated_time = cache.get(key, 0)
if updated_time > self.created_time:
return True
else:
return False
@classmethod
def set_changed(cls, tp="tree", t=None, org_id=None):
if org_id is None:
org_id = current_org.id
if tp == "assets":
key = cls.assets_updated_time_cache_key.format(org_id)
else:
key = cls.updated_time_cache_key.format(org_id)
ttl = cls.cache_time
if not t:
t = time.time()
cache.set(key, t, ttl)
def tree_has_changed(self):
return self._has_changed("tree")
def set_tree_changed(self, t=None):
logger.debug("Set tree tree changed")
self.__class__.set_changed(t=t, tp="tree")
def assets_has_changed(self):
return self._has_changed("assets")
def set_tree_assets_changed(self, t=None):
logger.debug("Set tree assets changed")
self.__class__.set_changed(t=t, tp="assets")
def get(self):
if self.tree_has_changed():
self.renew()
return self.tree
if self.assets_has_changed():
self.tree.init_assets()
return self.tree
def renew(self):
new_obj = self.__class__.new(self.org_id)
self.tree = new_obj.tree
self.created_time = new_obj.created_time
self.assets_created_time = new_obj.assets_created_time
@classmethod
def new(cls, org_id=None):
from ..utils import TreeService
logger.debug("Create node tree")
if not org_id:
org_id = current_org.id
with tmp_to_org(org_id):
tree = TreeService.new()
obj = cls(tree, org_id)
obj.tree = tree
return obj
class TreeMixin:
_org_tree_map = {}
@classmethod
def tree(cls):
org_id = current_org.org_id()
t = cls.get_local_tree_cache(org_id)
if t is None:
t = TreeCache.new()
cls._org_tree_map[org_id] = t
return t.get()
@classmethod
def get_local_tree_cache(cls, org_id=None):
t = cls._org_tree_map.get(org_id)
return t
@classmethod
def refresh_tree(cls, t=None):
TreeCache.set_changed(tp="tree", t=t, org_id=current_org.id)
@classmethod
def refresh_node_assets(cls, t=None):
TreeCache.set_changed(tp="assets", t=t, org_id=current_org.id)
class FamilyMixin:
@@ -41,16 +138,16 @@ class FamilyMixin:
@staticmethod
def clean_children_keys(nodes_keys):
sort_key = lambda k: [int(i) for i in k.split(':')]
nodes_keys = sorted(list(nodes_keys), key=sort_key)
nodes_keys = sorted(list(nodes_keys), key=lambda x: (len(x), x))
nodes_keys_clean = []
base_key = ''
for key in nodes_keys:
if key.startswith(base_key + ':'):
continue
nodes_keys_clean.append(key)
base_key = key
for key in nodes_keys[::-1]:
found = False
for k in nodes_keys:
if key.startswith(k + ':'):
found = True
break
if not found:
nodes_keys_clean.append(key)
return nodes_keys_clean
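
For illustration (not part of the diff above): both versions of `clean_children_keys` keep only the top-most keys of a set, dropping keys already covered by an ancestor (the two implementations may return them in a different order):

# clean_children_keys(['1:1', '1:1:2', '1:2', '1:2:3:4'])  ->  keys {'1:1', '1:2'}
# '1:1:2' is under '1:1' and '1:2:3:4' is under '1:2', so only the ancestors remain.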
@classmethod
@@ -78,16 +175,13 @@ class FamilyMixin:
return re.match(children_pattern, self.key)
def get_children(self, with_self=False):
q = Q(parent_key=self.key)
if with_self:
q |= Q(key=self.key)
return Node.objects.filter(q)
pattern = self.get_children_key_pattern(with_self=with_self)
return Node.objects.filter(key__regex=pattern)
def get_all_children(self, with_self=False):
q = Q(key__istartswith=f'{self.key}:')
if with_self:
q |= Q(key=self.key)
return Node.objects.filter(q)
pattern = self.get_all_children_pattern(with_self=with_self)
children = Node.objects.filter(key__regex=pattern)
return children
@property
def children(self):
@@ -97,30 +191,14 @@ class FamilyMixin:
def all_children(self):
return self.get_all_children(with_self=False)
def create_child(self, value=None, _id=None):
with atomic(savepoint=False):
def create_child(self, value, _id=None):
with transaction.atomic():
child_key = self.get_next_child_key()
if value is None:
value = child_key
child = self.__class__.objects.create(
id=_id, key=child_key, value=value, parent_key=self.key,
id=_id, key=child_key, value=value
)
return child
def get_or_create_child(self, value, _id=None):
"""
:return: Node, bool (created)
"""
children = self.get_children()
exist = children.filter(value=value).exists()
if exist:
child = children.filter(value=value).first()
created = False
else:
child = self.create_child(value, _id)
created = True
return child, created
def get_next_child_key(self):
mark = self.child_mark
self.child_mark += 1
@@ -163,13 +241,10 @@ class FamilyMixin:
ancestor_keys = self.get_ancestor_keys(with_self=with_self)
return self.__class__.objects.filter(key__in=ancestor_keys)
# @property
# def parent_key(self):
# parent_key = ":".join(self.key.split(":")[:-1])
# return parent_key
def compute_parent_key(self):
return compute_parent_key(self.key)
@property
def parent_key(self):
parent_key = ":".join(self.key.split(":")[:-1])
return parent_key
def is_parent(self, other):
return other.is_children(self)
@@ -205,63 +280,45 @@ class FamilyMixin:
sibling = sibling.exclude(key=self.key)
return sibling
@classmethod
def create_node_by_full_value(cls, full_value):
if not full_value:
return []
nodes_family = full_value.split('/')
nodes_family = [v for v in nodes_family if v]
org_root = cls.org_root()
if nodes_family[0] == org_root.value:
nodes_family = nodes_family[1:]
return cls.create_nodes_recurse(nodes_family, org_root)
@classmethod
def create_nodes_recurse(cls, values, parent=None):
values = [v for v in values if v]
if not values:
return None
if parent is None:
parent = cls.org_root()
value = values[0]
child, created = parent.get_or_create_child(value=value)
if len(values) == 1:
return child
return cls.create_nodes_recurse(values[1:], child)
def get_family(self):
ancestors = self.get_ancestors()
children = self.get_all_children()
return [*tuple(ancestors), self, *tuple(children)]
class FullValueMixin:
key = ''
@lazyproperty
def full_value(self):
if self.is_org_root():
return self.value
value = self.tree().get_node_full_tag(self.key)
return value
class NodeAssetsMixin:
key = ''
id = None
@lazyproperty
def assets_amount(self):
amount = self.tree().assets_amount(self.key)
return amount
def get_all_assets(self):
from .asset import Asset
q = Q(nodes__key__startswith=f'{self.key}:') | Q(nodes__key=self.key)
return Asset.objects.filter(q).distinct()
@classmethod
def get_node_all_assets_by_key_v2(cls, key):
# The original implementation was:
# Asset.objects.filter(Q(nodes__key__startswith=f'{node.key}:') | Q(nodes__id=node.id))
# But `startswith` causes the Asset index to be skipped when the tables are joined
from .asset import Asset
node_ids = cls.objects.filter(
Q(key__startswith=f'{key}:') |
Q(key=key)
).values_list('id', flat=True).distinct()
assets = Asset.objects.filter(
nodes__id__in=list(node_ids)
).distinct()
return assets
if self.is_org_root():
return Asset.objects.filter(org_id=self.org_id)
pattern = '^{0}$|^{0}:'.format(self.key)
return Asset.objects.filter(nodes__key__regex=pattern).distinct()
def get_assets(self):
from .asset import Asset
assets = Asset.objects.filter(nodes=self)
if self.is_org_root():
assets = Asset.objects.filter(Q(nodes=self) | Q(nodes__isnull=True))
else:
assets = Asset.objects.filter(nodes=self)
return assets.distinct()
def get_valid_assets(self):
@@ -270,54 +327,51 @@ class NodeAssetsMixin:
def get_all_valid_assets(self):
return self.get_all_assets().valid()
@classmethod
def _get_nodes_all_assets(cls, nodes_keys):
"""
When there are many nodes, this regex-based approach performs terribly
:param nodes_keys:
:return:
"""
from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys)
nodes_children_pattern = set()
for key in nodes_keys:
children_pattern = cls.get_node_all_children_key_pattern(key)
nodes_children_pattern.add(children_pattern)
pattern = '|'.join(nodes_children_pattern)
return Asset.objects.filter(nodes__key__regex=pattern).distinct()
@classmethod
def get_nodes_all_assets_ids(cls, nodes_keys):
assets_ids = cls.get_nodes_all_assets(nodes_keys).values_list('id', flat=True)
nodes_keys = cls.clean_children_keys(nodes_keys)
assets_ids = set()
for key in nodes_keys:
node_assets_ids = cls.tree().all_assets(key)
assets_ids.update(set(node_assets_ids))
return assets_ids
@classmethod
def get_nodes_all_assets(cls, nodes_keys, extra_assets_ids=None):
from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys)
q = Q()
node_ids = ()
for key in nodes_keys:
q |= Q(key__startswith=f'{key}:')
q |= Q(key=key)
if q:
node_ids = Node.objects.filter(q).distinct().values_list('id', flat=True)
q = Q(nodes__id__in=list(node_ids))
assets_ids = cls.get_nodes_all_assets_ids(nodes_keys)
if extra_assets_ids:
q |= Q(id__in=extra_assets_ids)
if q:
return Asset.org_objects.filter(q).distinct()
else:
return Asset.objects.none()
assets_ids.update(set(extra_assets_ids))
return Asset.objects.filter(id__in=assets_ids)
class SomeNodesMixin:
key = ''
default_key = '1'
default_value = 'Default'
ungrouped_key = '-10'
ungrouped_value = _('ungrouped')
empty_key = '-11'
empty_value = _("empty")
@classmethod
def default_node(cls):
with tmp_to_org(Organization.default()):
defaults = {'value': cls.default_value}
try:
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
except IntegrityError as e:
logger.error("Create default node failed: {}".format(e))
cls.modify_other_org_root_node_key()
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
return obj
favorite_key = '-12'
favorite_value = _("favorite")
def is_default_node(self):
return self.key == self.default_key
@@ -352,17 +406,51 @@ class SomeNodesMixin:
@classmethod
def org_root(cls):
root = cls.objects.filter(parent_key='')\
.filter(key__regex=r'^[0-9]+$')\
.exclude(key__startswith='-')
root = cls.objects.filter(key__regex=r'^[0-9]+$')
if root:
return root[0]
else:
return cls.create_org_root_node()
@classmethod
def ungrouped_node(cls):
with tmp_to_org(Organization.system()):
defaults = {'value': cls.ungrouped_value}
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.ungrouped_key
)
return obj
@classmethod
def default_node(cls):
with tmp_to_org(Organization.default()):
defaults = {'value': cls.default_value}
try:
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
except IntegrityError as e:
logger.error("Create default node failed: {}".format(e))
cls.modify_other_org_root_node_key()
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
return obj
@classmethod
def favorite_node(cls):
with tmp_to_org(Organization.system()):
defaults = {'value': cls.favorite_value}
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.favorite_key
)
return obj
@classmethod
def initial_some_nodes(cls):
cls.default_node()
cls.ungrouped_node()
cls.favorite_node()
@classmethod
def modify_other_org_root_node_key(cls):
@@ -394,16 +482,12 @@ class SomeNodesMixin:
logger.info('Modify key ( {} > {} )'.format(old_key, new_key))
class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
class Node(OrgModelMixin, SomeNodesMixin, TreeMixin, FamilyMixin, FullValueMixin, NodeAssetsMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1'
value = models.CharField(max_length=128, verbose_name=_("Value"))
full_value = models.CharField(max_length=4096, verbose_name=_('Full value'), default='')
child_mark = models.IntegerField(default=0)
date_create = models.DateTimeField(auto_now_add=True)
parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"),
db_index=True, default='')
assets_amount = models.IntegerField(default=0)
objects = OrgManager.from_queryset(NodeQuerySet)()
is_node = True
@@ -411,10 +495,10 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
class Meta:
verbose_name = _("Node")
ordering = ['value']
ordering = ['key']
def __str__(self):
return self.full_value
return self.value
# def __eq__(self, other):
# if not other:
@@ -438,19 +522,18 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
def name(self):
return self.value
def computed_full_value(self):
# Do not call this property when handling lists of nodes
values = self.__class__.objects.filter(
key__in=self.get_ancestor_keys()
).values_list('key', 'value')
values = [v for k, v in sorted(values, key=lambda x: len(x[0]))]
values.append(str(self.value))
return '/' + '/'.join(values)
@property
def level(self):
return len(self.key.split(':'))
@classmethod
def refresh_nodes(cls):
cls.refresh_tree()
@classmethod
def refresh_assets(cls):
cls.refresh_node_assets()
def as_tree_node(self):
from common.tree import TreeNode
name = '{} ({})'.format(self.value, self.assets_amount)
@@ -485,21 +568,19 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
return
return super().delete(using=using, keep_parents=keep_parents)
def update_child_full_value(self):
nodes = self.get_all_children(with_self=True)
sort_key_func = lambda n: [int(i) for i in n.key.split(':')]
nodes_sorted = sorted(list(nodes), key=sort_key_func)
nodes_mapper = {n.key: n for n in nodes_sorted}
for node in nodes_sorted:
parent = nodes_mapper.get(node.parent_key)
if not parent:
logger.error(f'Node parent not found in mapper: {node.parent_key} {node.value}')
continue
node.full_value = parent.full_value + '/' + node.value
self.__class__.objects.bulk_update(nodes, ['full_value'])
@classmethod
def generate_fake(cls, count=100):
import random
org = get_current_org()
if not org or not org.is_real():
Organization.default().change_to()
nodes = list(cls.objects.all())
if count > 100:
length = 100
else:
length = count
def save(self, *args, **kwargs):
self.full_value = self.computed_full_value()
instance = super().save(*args, **kwargs)
self.update_child_full_value()
return instance
for i in range(length):
node = random.choice(nodes)
child = node.create_child('Node {}'.format(i))
print("{}. {}".format(i, child))

View File

@@ -9,7 +9,6 @@ from django.utils.translation import ugettext_lazy as _
from django.core.validators import MinValueValidator, MaxValueValidator
from common.utils import signer
from common.fields.model import JsonListCharField
from .base import BaseUser
from .asset import Asset
@@ -65,6 +64,26 @@ class AdminUser(BaseUser):
unique_together = [('name', 'org_id')]
verbose_name = _("Admin user")
@classmethod
def generate_fake(cls, count=10):
from random import seed
import forgery_py
from django.db import IntegrityError
seed()
for i in range(count):
obj = cls(name=forgery_py.name.full_name(),
username=forgery_py.internet.user_name(),
password=forgery_py.lorem_ipsum.word(),
comment=forgery_py.lorem_ipsum.sentence(),
created_by='Fake')
try:
obj.save()
logger.debug('Generate fake asset group: %s' % obj.name)
except IntegrityError:
print('Error continue')
continue
class SystemUser(BaseUser):
PROTOCOL_SSH = 'ssh'
@@ -72,20 +91,12 @@ class SystemUser(BaseUser):
PROTOCOL_TELNET = 'telnet'
PROTOCOL_VNC = 'vnc'
PROTOCOL_MYSQL = 'mysql'
PROTOCOL_ORACLE = 'oracle'
PROTOCOL_MARIADB = 'mariadb'
PROTOCOL_POSTGRESQL = 'postgresql'
PROTOCOL_K8S = 'k8s'
PROTOCOL_CHOICES = (
(PROTOCOL_SSH, 'ssh'),
(PROTOCOL_RDP, 'rdp'),
(PROTOCOL_TELNET, 'telnet'),
(PROTOCOL_VNC, 'vnc'),
(PROTOCOL_MYSQL, 'mysql'),
(PROTOCOL_ORACLE, 'oracle'),
(PROTOCOL_MARIADB, 'mariadb'),
(PROTOCOL_POSTGRESQL, 'postgresql'),
(PROTOCOL_K8S, 'k8s'),
)
LOGIN_AUTO = 'auto'
@@ -107,10 +118,6 @@ class SystemUser(BaseUser):
login_mode = models.CharField(choices=LOGIN_MODE_CHOICES, default=LOGIN_AUTO, max_length=10, verbose_name=_('Login mode'))
cmd_filters = models.ManyToManyField('CommandFilter', related_name='system_users', verbose_name=_("Command filter"), blank=True)
sftp_root = models.CharField(default='tmp', max_length=128, verbose_name=_("SFTP Root"))
token = models.TextField(default='', verbose_name=_('Token'))
home = models.CharField(max_length=4096, default='', verbose_name=_('Home'), blank=True)
system_groups = models.CharField(default='', max_length=4096, verbose_name=_('System groups'), blank=True)
ad_domain = models.CharField(default='', max_length=256)
_prefer = 'system_user'
def __str__(self):
@@ -133,24 +140,6 @@ class SystemUser(BaseUser):
def login_mode_display(self):
return self.get_login_mode_display()
@property
def db_application_protocols(self):
return [
self.PROTOCOL_MYSQL, self.PROTOCOL_ORACLE, self.PROTOCOL_MARIADB,
self.PROTOCOL_POSTGRESQL
]
@property
def cloud_application_protocols(self):
return [self.PROTOCOL_K8S]
@property
def application_category_protocols(self):
protocols = []
protocols.extend(self.db_application_protocols)
protocols.extend(self.cloud_application_protocols)
return protocols
def is_need_push(self):
if self.auto_push and self.protocol in [self.PROTOCOL_SSH, self.PROTOCOL_RDP]:
return True
@@ -163,16 +152,11 @@ class SystemUser(BaseUser):
@property
def is_need_test_asset_connective(self):
return self.protocol not in self.application_category_protocols
def has_special_auth(self, asset=None, username=None):
if username is None and self.username_same_with_user:
raise TypeError('System user is dynamic, username should be pass')
return super().has_special_auth(asset=asset, username=username)
return self.protocol not in [self.PROTOCOL_MYSQL]
@property
def can_perm_to_asset(self):
return self.protocol not in self.application_category_protocols
return self.protocol not in [self.PROTOCOL_MYSQL]
def _merge_auth(self, other):
super()._merge_auth(other)
@@ -209,3 +193,23 @@ class SystemUser(BaseUser):
ordering = ['name']
unique_together = [('name', 'org_id')]
verbose_name = _("System user")
@classmethod
def generate_fake(cls, count=10):
from random import seed
import forgery_py
from django.db import IntegrityError
seed()
for i in range(count):
obj = cls(name=forgery_py.name.full_name(),
username=forgery_py.internet.user_name(),
password=forgery_py.lorem_ipsum.word(),
comment=forgery_py.lorem_ipsum.sentence(),
created_by='Fake')
try:
obj.save()
logger.debug('Generate fake asset group: %s' % obj.name)
except IntegrityError:
print('Error continue')
continue

View File

@@ -1,39 +0,0 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request
from assets.models import Node
class AssetLimitOffsetPagination(LimitOffsetPagination):
"""
Must be used together with `assets.api.mixin.FilterAssetByNodeMixin`
"""
def get_count(self, queryset):
"""
1. When querying all assets under a node, use Node.assets_amount as the count
2. When other filter conditions are present, fall back to super()
3. When querying only the assets directly under the node, fall back to super()
"""
exclude_query_params = {
self.limit_query_param,
self.offset_query_param,
'node', 'all', 'show_current_asset',
'node_id', 'display', 'draw', 'fields_size',
}
for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None:
return super().get_count(queryset)
is_query_all = self._view.is_query_node_all_assets
if is_query_all:
node = self._view.node
if not node:
node = Node.org_root()
return node.assets_amount
return super().get_count(queryset)
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request
self._view = view
return super().paginate_queryset(queryset, request, view=None)
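
For illustration (not part of the diff above): which requests the deleted paginator could answer from the stored `Node.assets_amount` instead of a real count; the URLs and parameter values are hypothetical.

# GET /api/v1/assets/assets/?node_id=<id>&all=1&limit=15&offset=0
#   only excluded params present and the view queries a whole subtree
#   -> count = node.assets_amount, no COUNT(*) over assets
# GET /api/v1/assets/assets/?node_id=<id>&all=1&hostname=web01
#   'hostname' is not in exclude_query_params -> super().get_count(queryset)
# GET /api/v1/assets/assets/?node_id=<id>
#   only the node's direct assets are requested -> super().get_count(queryset)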

View File

@@ -1,12 +1,13 @@
# -*- coding: utf-8 -*-
#
from rest_framework import serializers
from django.db.models import F
from django.db.models import Prefetch, F, Count
from django.utils.translation import ugettext_lazy as _
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from ..models import Asset, Node, Platform
from common.serializers import AdaptedBulkListSerializer
from ..models import Asset, Node, Label, Platform
from .base import ConnectivitySerializer
__all__ = [
@@ -66,15 +67,12 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
slug_field='name', queryset=Platform.objects.all(), label=_("Platform")
)
protocols = ProtocolsField(label=_('Protocols'), required=False)
domain_display = serializers.ReadOnlyField(source='domain.name', label=_('Domain name'))
admin_user_display = serializers.ReadOnlyField(source='admin_user.name', label=_('Admin user name'))
nodes_display = serializers.ListField(child=serializers.CharField(), label=_('Nodes name'), required=False)
"""
Data structure of an asset
"""
class Meta:
model = Asset
list_serializer_class = AdaptedBulkListSerializer
fields_mini = ['id', 'hostname', 'ip']
fields_small = fields_mini + [
'protocol', 'port', 'protocols', 'is_active', 'public_ip',
@@ -84,13 +82,13 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
'created_by', 'date_created', 'hardware_info',
]
fields_fk = [
'admin_user', 'admin_user_display', 'domain', 'domain_display', 'platform'
'admin_user', 'admin_user_display', 'domain', 'platform'
]
fk_only_fields = {
'platform': ['name']
}
fields_m2m = [
'nodes', 'nodes_display', 'labels',
'nodes', 'labels',
]
annotates_fields = {
# 'admin_user_display': 'admin_user__name'
@@ -115,7 +113,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.select_related('admin_user', 'domain', 'platform')
queryset = queryset.prefetch_related('nodes', 'labels')
return queryset
def compatible_with_old_protocol(self, validated_data):
@@ -133,32 +130,14 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
if protocols_data:
validated_data["protocols"] = ' '.join(protocols_data)
def perform_nodes_display_create(self, instance, nodes_display):
if not nodes_display:
return
nodes_to_set = []
for full_value in nodes_display:
node = Node.objects.filter(full_value=full_value).first()
if node:
nodes_to_set.append(node)
else:
node = Node.create_node_by_full_value(full_value)
nodes_to_set.append(node)
instance.nodes.set(nodes_to_set)
def create(self, validated_data):
self.compatible_with_old_protocol(validated_data)
nodes_display = validated_data.pop('nodes_display', '')
instance = super().create(validated_data)
self.perform_nodes_display_create(instance, nodes_display)
return instance
def update(self, instance, validated_data):
nodes_display = validated_data.pop('nodes_display', '')
self.compatible_with_old_protocol(validated_data)
instance = super().update(instance, validated_data)
self.perform_nodes_display_create(instance, nodes_display)
return instance
return super().update(instance, validated_data)
class AssetDisplaySerializer(AssetSerializer):
@@ -171,7 +150,7 @@ class AssetDisplaySerializer(AssetSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
queryset = super().setup_eager_loading(queryset)
""" Perform necessary eager loading of data. """
queryset = queryset\
.annotate(admin_user_username=F('admin_user__username'))
return queryset

View File

@@ -1,19 +1,17 @@
# -*- coding: utf-8 -*-
#
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from common.serializers import AdaptedBulkListSerializer
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from common.validators import NoSpecialChars
from ..models import Domain, Gateway
from .base import AuthSerializerMixin
class DomainSerializer(BulkOrgResourceModelSerializer):
asset_count = serializers.SerializerMethodField(label=_('Assets count'))
application_count = serializers.SerializerMethodField(label=_('Applications count'))
gateway_count = serializers.SerializerMethodField(label=_('Gateways count'))
asset_count = serializers.SerializerMethodField()
gateway_count = serializers.SerializerMethodField()
class Meta:
model = Domain
@@ -22,12 +20,12 @@ class DomainSerializer(BulkOrgResourceModelSerializer):
'comment', 'date_created'
]
fields_m2m = [
'asset_count', 'assets', 'application_count', 'gateway_count',
'asset_count', 'assets', 'gateway_count',
]
fields = fields_small + fields_m2m
read_only_fields = ('asset_count', 'gateway_count', 'date_created')
extra_kwargs = {
'assets': {'required': False, 'label': _('Assets')},
'assets': {'required': False}
}
list_serializer_class = AdaptedBulkListSerializer
@@ -35,10 +33,6 @@ class DomainSerializer(BulkOrgResourceModelSerializer):
def get_asset_count(obj):
return obj.assets.count()
@staticmethod
def get_application_count(obj):
return obj.applications.count()
@staticmethod
def get_gateway_count(obj):
return obj.gateway_set.all().count()
@@ -53,9 +47,6 @@ class GatewaySerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
'private_key', 'public_key', 'domain', 'is_active', 'date_created',
'date_updated', 'created_by', 'comment',
]
extra_kwargs = {
'password': {'validators': [NoSpecialChars()]}
}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -25,9 +25,6 @@ class NodeSerializer(BulkOrgResourceModelSerializer):
read_only_fields = ['key', 'org_id']
def validate_value(self, data):
if '/' in data:
error = _("Can't contains: " + "/")
raise serializers.ValidationError(error)
if self.instance:
instance = self.instance
siblings = instance.get_siblings()

View File

@@ -6,6 +6,7 @@ from common.serializers import AdaptedBulkListSerializer
from common.mixins.serializers import BulkSerializerMixin
from common.utils import ssh_pubkey_gen
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from assets.models import Node
from ..models import SystemUser, Asset
from .base import AuthSerializerMixin
@@ -32,20 +33,17 @@ class SystemUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
'login_mode', 'login_mode_display',
'priority', 'username_same_with_user',
'auto_push', 'cmd_filters', 'sudo', 'shell', 'comment',
'auto_generate_key', 'sftp_root', 'token',
'assets_amount', 'date_created', 'created_by',
'home', 'system_groups', 'ad_domain'
'auto_generate_key', 'sftp_root',
'assets_amount', 'date_created', 'created_by'
]
extra_kwargs = {
'password': {"write_only": True},
'public_key': {"write_only": True},
'private_key': {"write_only": True},
'token': {"write_only": True},
'nodes_amount': {'label': _('Nodes amount')},
'assets_amount': {'label': _('Assets amount')},
'nodes_amount': {'label': _('Node')},
'assets_amount': {'label': _('Asset')},
'login_mode_display': {'label': _('Login mode display')},
'created_by': {'read_only': True},
'ad_domain': {'required': False, 'allow_blank': True, 'label': _('Ad domain')},
}
def validate_auto_push(self, value):
@@ -145,28 +143,16 @@ class SystemUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
class SystemUserListSerializer(SystemUserSerializer):
class Meta(SystemUserSerializer.Meta):
fields = [
'id', 'name', 'username', 'protocol',
'password', 'public_key', 'private_key',
'login_mode', 'login_mode_display',
'priority', "username_same_with_user",
'auto_push', 'sudo', 'shell', 'comment',
"assets_amount", 'home', 'system_groups',
'auto_generate_key', 'ad_domain',
"assets_amount",
'auto_generate_key',
'sftp_root',
]
extra_kwargs = {
'password': {"write_only": True},
'public_key': {"write_only": True},
'private_key': {"write_only": True},
'nodes_amount': {'label': _('Nodes amount')},
'assets_amount': {'label': _('Assets amount')},
'login_mode_display': {'label': _('Login mode display')},
'created_by': {'read_only': True},
'ad_domain': {'label': _('Ad domain')},
}
@classmethod
def setup_eager_loading(cls, queryset):
@@ -183,8 +169,7 @@ class SystemUserWithAuthInfoSerializer(SystemUserSerializer):
'login_mode', 'login_mode_display',
'priority', 'username_same_with_user',
'auto_push', 'sudo', 'shell', 'comment',
'auto_generate_key', 'sftp_root', 'token',
'ad_domain',
'auto_generate_key', 'sftp_root',
]
extra_kwargs = {
'nodes_amount': {'label': _('Node')},
@@ -234,8 +219,15 @@ class SystemUserNodeRelationSerializer(RelationMixin, serializers.ModelSerialize
'id', 'node', "node_display",
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tree = Node.tree()
def get_node_display(self, obj):
return obj.node.full_value
if hasattr(obj, 'node_key'):
return self.tree.get_node_full_tag(obj.node_key)
else:
return obj.node.full_value
class SystemUserUserRelationSerializer(RelationMixin, serializers.ModelSerializer):

View File

@@ -1,20 +1,17 @@
# -*- coding: utf-8 -*-
#
from operator import add, sub
from assets.utils import is_asset_exists_in_node
from collections import defaultdict
from django.db.models.signals import (
post_save, m2m_changed, pre_delete, post_delete
post_save, m2m_changed, post_delete
)
from django.db.models import Q, F
from django.db.models.aggregates import Count
from django.dispatch import receiver
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import PRE_ADD, POST_ADD, POST_REMOVE, PRE_CLEAR, PRE_REMOVE
from common.utils import get_logger
from common.decorator import on_transaction_commit
from .models import Asset, SystemUser, Node, compute_parent_key
from users.models import User
from orgs.utils import tmp_to_root_org
from .models import Asset, SystemUser, Node, AuthBook
from .utils import TreeService
from .tasks import (
update_assets_hardware_info_util,
test_asset_connectivity_util,
@@ -57,9 +54,17 @@ def on_asset_created_or_update(sender, instance=None, created=False, **kwargs):
instance.nodes.add(Node.org_root())
@receiver(post_delete, sender=Asset)
def on_asset_delete(sender, instance=None, **kwargs):
"""
When an asset is deleted, refresh the node tree, since it caches the node-asset relationships
"""
logger.debug("Asset delete signal recv: {}".format(instance))
Node.refresh_assets()
@receiver(post_save, sender=SystemUser, dispatch_uid="jms")
@on_transaction_commit
def on_system_user_update(instance: SystemUser, created, **kwargs):
def on_system_user_update(sender, instance=None, created=True, **kwargs):
"""
When a system user is updated, its key, username, etc. may have changed, so the system user is pushed to its assets automatically.
Strictly it should only be pushed when the username, password, key, sudo, etc. actually change; this takes a shortcut,
@@ -69,25 +74,26 @@ def on_system_user_update(instance: SystemUser, created, **kwargs):
if instance and not created:
logger.info("System user update signal recv: {}".format(instance))
assets = instance.assets.all().valid()
push_system_user_to_assets.delay(instance.id, [_asset.id for _asset in assets])
push_system_user_to_assets.delay(instance, assets)
@receiver(m2m_changed, sender=SystemUser.assets.through)
def on_system_user_assets_change(instance, action, model, pk_set, **kwargs):
def on_system_user_assets_change(sender, instance=None, action='', model=None, pk_set=None, **kwargs):
"""
When the relationship between a system user and assets changes, re-push the system user to the newly added assets
"""
if action != POST_ADD:
if action != "post_add":
return
logger.debug("System user assets change signal recv: {}".format(instance))
queryset = model.objects.filter(pk__in=pk_set)
if model == Asset:
system_users_id = [instance.id]
assets_id = pk_set
system_users = [instance]
assets = queryset
else:
system_users_id = pk_set
assets_id = [instance.id]
for system_user_id in system_users_id:
push_system_user_to_assets.delay(system_user_id, assets_id)
system_users = queryset
assets = [instance]
for system_user in system_users:
push_system_user_to_assets.delay(system_user, assets)
@receiver(m2m_changed, sender=SystemUser.users.through)
@@ -95,7 +101,7 @@ def on_system_user_users_change(sender, instance=None, action='', model=None, pk
"""
When the relationship between a system user and users changes, re-push the system user to its assets
"""
if action != POST_ADD:
if action != "post_add":
return
if not instance.username_same_with_user:
return
@@ -114,7 +120,7 @@ def on_system_user_nodes_change(sender, instance=None, action=None, model=None,
"""
When the relationship between a system user and nodes changes, associate the assets under those nodes with the system user
"""
if action != POST_ADD:
if action != "post_add":
return
logger.info("System user nodes update signal recv: {}".format(instance))
@@ -129,217 +135,102 @@ def on_system_user_nodes_change(sender, instance=None, action=None, model=None,
@receiver(m2m_changed, sender=SystemUser.groups.through)
def on_system_user_groups_change(instance, action, pk_set, reverse, **kwargs):
def on_system_user_groups_change(sender, instance=None, action=None, model=None,
pk_set=None, reverse=False, **kwargs):
"""
When the relationship between a system user and user groups changes, associate the users in those groups with the system user
"""
if action != POST_ADD:
if action != "post_add" or reverse:
return
if reverse:
raise M2MReverseNotAllowed
logger.info("System user groups update signal recv: {}".format(instance))
users = User.objects.filter(groups__id__in=pk_set).distinct()
instance.users.add(*users)
groups = model.objects.filter(pk__in=pk_set).annotate(users_count=Count("users"))
users = groups.filter(users_count__gt=0).values_list('users', flat=True)
instance.users.add(*tuple(users))
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs):
def on_asset_nodes_change(sender, instance=None, action='', **kwargs):
"""
This operation hits the database 4 times in total
When an asset's nodes change, refresh the node tree
"""
if action.startswith('post'):
logger.debug("Asset nodes change signal recv: {}".format(instance))
Node.refresh_assets()
with tmp_to_root_org():
Node.refresh_assets()
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_nodes_add(sender, instance=None, action='', model=None, pk_set=None, **kwargs):
"""
When an asset's nodes change, or when a node's asset relationships change,
add the assets newly placed under the node to the system users bound to that node
"""
if action != POST_ADD:
if action != "post_add":
return
logger.debug("Assets node add signal recv: {}".format(action))
if reverse:
nodes = [instance.key]
asset_ids = pk_set
if model == Node:
nodes = model.objects.filter(pk__in=pk_set).values_list('key', flat=True)
assets = [instance.id]
else:
nodes = Node.objects.filter(pk__in=pk_set).values_list('key', flat=True)
asset_ids = [instance.id]
nodes = [instance.key]
assets = model.objects.filter(pk__in=pk_set).values_list('id', flat=True)
# When a node's assets change, bind the assets to the system users of the node and its ancestor nodes; only additions are handled here
nodes_ancestors_keys = set()
node_tree = TreeService.new()
for node in nodes:
nodes_ancestors_keys.update(Node.get_node_ancestor_keys(node, with_self=True))
ancestors_keys = node_tree.ancestors_ids(nid=node)
nodes_ancestors_keys.update(ancestors_keys)
system_users = SystemUser.objects.filter(nodes__key__in=nodes_ancestors_keys)
# Query the system users bound to any of the ancestor nodes; all of them need a relation to these assets
system_user_ids = SystemUser.objects.filter(
nodes__key__in=nodes_ancestors_keys
).distinct().values_list('id', flat=True)
# Query all relations that already exist
m2m_model = SystemUser.assets.through
exist = set(m2m_model.objects.filter(
systemuser_id__in=system_user_ids, asset_id__in=asset_ids
).values_list('systemuser_id', 'asset_id'))
# TODO: optimize
to_create = []
for system_user_id in system_user_ids:
asset_ids_to_push = []
for asset_id in asset_ids:
if (system_user_id, asset_id) in exist:
continue
asset_ids_to_push.append(asset_id)
to_create.append(m2m_model(
systemuser_id=system_user_id,
asset_id=asset_id
))
push_system_user_to_assets.delay(system_user_id, asset_ids_to_push)
m2m_model.objects.bulk_create(to_create)
def _update_node_assets_amount(node: Node, asset_pk_set: set, operator=add):
"""
Update counters when the relationship between one node and multiple assets changes
:param node: the node instance
:param asset_pk_set: set of asset ids; not modified in place
:param operator: the operation to apply (add or sub)
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# Get the `key`s of node [1]'s ancestors including itself, i.e. the `key`s of nodes [1], [2] and [3]
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# Iterate over the ancestors' `key`s in the order [1] -> [2] -> [3]
# Check whether this node or its descendants already contain any of the assets being
# processed and drop those from the working set; they are duplicates, and neither adding
# nor removing them changes this node's asset count
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# The working set is empty, so all assets were duplicates and no counts change.
# And since this node already contains them, its ancestors necessarily contain
# them too, so the remaining ancestors can be skipped
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))
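
For illustration (not part of the diff above): `operator(F('assets_amount'), n)` builds a database-side increment or decrement, so the counter is never read into Python first. A minimal sketch with an assumed `ancestor` node instance and `n` new assets:

from operator import add, sub
from django.db.models import F

n = 3   # number of genuinely new (non-duplicate) assets for this ancestor
# operator=add  ->  UPDATE ... SET assets_amount = assets_amount + 3
ancestor.assets_amount = add(F('assets_amount'), n)
# operator=sub  ->  UPDATE ... SET assets_amount = assets_amount - 3
# ancestor.assets_amount = sub(F('assets_amount'), n)
ancestor.save(update_fields=['assets_amount'])   # or collected and bulk_update()-ed as above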
def _remove_ancestor_keys(ancestor_key, tree_set):
# Require `ancestor_key` to be non-empty here to avoid an endless loop caused by bad data
# Membership in the set tells apart keys that have already been processed
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
def _update_nodes_asset_amount(node_keys, asset_pk, operator):
"""
Update counters when the relationship between one asset and multiple nodes changes
:param node_keys: set of node keys
:param asset_pk: the asset id
:param operator: the operation to apply (add or sub)
"""
# The ancestors of all affected nodes together form a local subtree
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# An affected node may itself be an ancestor of another affected node; if so, drop it from the affected set
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# Iterate over the affected nodes, handling each node and its ancestors
# Check whether this node already contains the asset being processed
exists = is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# If the asset is already in this node, neither it nor its ancestors need any processing
_remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# Not present, so this node's count has to be updated
to_update_keys.append(key)
# Require `parent_key` to be non-empty here to avoid an endless loop caused by bad data
# Membership in the set tells apart keys that have already been processed
while parent_key and parent_key in ancestor_keys:
exists = is_asset_exists_in_node(asset_pk, parent_key)
if exists:
_remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
system_users_assets = defaultdict(set)
for system_user in system_users:
system_users_assets[system_user].update(set(assets))
for system_user, _assets in system_users_assets.items():
system_user.assets.add(*tuple(_assets))
@receiver(m2m_changed, sender=Asset.nodes.through)
def update_nodes_assets_amount(action, instance, reverse, pk_set, **kwargs):
# `pre_clear` is not allowed because that signal does not carry `pk_set`
# [Official docs](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
refused = (PRE_CLEAR,)
if action in refused:
raise ValueError
def on_asset_nodes_remove(sender, instance=None, action='', model=None,
pk_set=None, **kwargs):
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
if action not in mapper:
"""
Watch assets being removed from nodes (or nodes removing assets) to avoid leaving orphaned assets
"""
if action not in ["post_remove", "pre_clear", "post_clear"]:
return
operator = mapper[action]
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
_update_node_assets_amount(node, asset_pk_set, operator)
if action == "pre_clear":
if model == Node:
instance._nodes = list(instance.nodes.all())
else:
instance._assets = list(instance.assets.all())
return
logger.debug("Assets node remove signal recv: {}".format(action))
if action == "post_remove":
queryset = model.objects.filter(pk__in=pk_set)
else:
asset_pk = instance.id
# Nodes directly associated with the asset
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
_update_nodes_asset_amount(node_keys, asset_pk, operator)
if model == Node:
queryset = instance._nodes
else:
queryset = instance._assets
if model == Node:
assets = [instance]
else:
assets = queryset
if isinstance(assets, list):
assets_not_has_node = []
for asset in assets:
if asset.nodes.all().count() == 0:
assets_not_has_node.append(asset.id)
else:
assets_not_has_node = assets.annotate(nodes_count=Count('nodes'))\
.filter(nodes_count=0).values_list('id', flat=True)
Node.org_root().assets.add(*tuple(assets_not_has_node))
RELATED_NODE_IDS = '_related_node_ids'
@receiver(pre_delete, sender=Asset)
def on_asset_delete(instance: Asset, using, **kwargs):
node_ids = set(Node.objects.filter(
assets=instance
).distinct().values_list('id', flat=True))
setattr(instance, RELATED_NODE_IDS, node_ids)
m2m_changed.send(
sender=Asset.nodes.through, instance=instance, reverse=False,
model=Node, pk_set=node_ids, using=using, action=PRE_REMOVE
)
@receiver(post_delete, sender=Asset)
def on_asset_post_delete(instance: Asset, using, **kwargs):
node_ids = getattr(instance, RELATED_NODE_IDS, None)
if node_ids:
m2m_changed.send(
sender=Asset.nodes.through, instance=instance, reverse=False,
model=Node, pk_set=node_ids, using=using, action=POST_REMOVE
)
@receiver([post_save, post_delete], sender=Node)
def on_node_update_or_created(sender, **kwargs):
# Refresh the node tree
Node.refresh_nodes()
with tmp_to_root_org():
Node.refresh_nodes()

View File

@@ -9,4 +9,3 @@ from .gather_asset_users import *
from .gather_asset_hardware_info import *
from .push_system_user import *
from .system_user_connectivity import *
from .nodes_amount import *

View File

@@ -85,7 +85,7 @@ def test_asset_user_connectivity_util(asset_user, task_name):
raw, summary = test_user_connectivity(
task_name=task_name, asset=asset_user.asset,
username=asset_user.username, password=asset_user.password,
private_key=asset_user.private_key_file
private_key=asset_user.private_key
)
except Exception as e:
logger.warn("Failed run adhoc {}, {}".format(task_name, e))

View File

@@ -64,7 +64,7 @@ GATHER_ASSET_USERS_TASKS = [
"action": {
"module": "shell",
"args": "users=$(getent passwd | grep -v 'nologin' | "
"grep -v 'shudown' | awk -F: '{ print $1 }');for i in $users;do last -w -F $i -1 | "
"grep -v 'shudown' | awk -F: '{ print $1 }');for i in $users;do last -F $i -1 | "
"head -1 | grep -v '^$' | awk '{ print $1\"@\"$3\"@\"$5,$6,$7,$8 }';done"
}
}

View File

@@ -129,5 +129,5 @@ def update_assets_hardware_info_period():
def update_node_assets_hardware_info_manual(node):
task_name = _("Update node asset hardware information: {}").format(node.name)
assets = node.get_all_assets()
result = update_assets_hardware_info_util(assets, task_name=task_name)
result = update_assets_hardware_info_util.delay(assets, task_name=task_name)
return result

View File

@@ -92,7 +92,7 @@ def add_asset_users(assets, results):
for username, data in users.items():
defaults = {'asset': asset, 'username': username, 'present': True}
if data.get("ip"):
defaults["ip_last_login"] = data["ip"][:32]
defaults["ip_last_login"] = data["ip"]
if data.get("date"):
defaults["date_last_login"] = data["date"]
GatheredUser.objects.update_or_create(

View File

@@ -1,14 +0,0 @@
from celery import shared_task
from assets.utils import check_node_assets_amount
from common.utils import get_logger
from common.utils.timezone import now
logger = get_logger(__file__)
@shared_task()
def check_node_assets_amount_celery_task():
logger.info(f'>>> {now()} begin check_node_assets_amount_celery_task ...')
check_node_assets_amount()
logger.info(f'>>> {now()} end check_node_assets_amount_celery_task ...')

View File

@@ -2,13 +2,10 @@
from itertools import groupby
from celery import shared_task
from common.db.utils import get_object_if_need, get_objects_if_need, get_objects
from django.utils.translation import ugettext as _
from django.db.models import Empty
from common.utils import encrypt_password, get_logger
from assets.models import SystemUser, Asset
from orgs.utils import org_aware_func
from orgs.utils import tmp_to_org, org_aware_func
from . import const
from .utils import clean_ansible_task_hosts, group_asset_by_platform
@@ -20,44 +17,20 @@ __all__ = [
]
def _split_by_comma(raw: str):
try:
return [i.strip() for i in raw.split(',')]
except AttributeError:
return []
def _dump_args(args: dict):
return ' '.join([f'{k}={v}' for k, v in args.items() if v is not Empty])
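
For illustration (not part of the diff above): how the two helpers turn system-user fields into Ansible `user`-module arguments, with `Empty` acting as the omit-this-key sentinel; the values are hypothetical.

# _split_by_comma('wheel, docker')  -> ['wheel', 'docker']
# _split_by_comma(None)             -> []   (the AttributeError is swallowed)
#
# _dump_args({'name': 'web', 'shell': Empty, 'groups': '"wheel,docker"', 'comment': 'web user'})
#   -> 'name=web groups="wheel,docker" comment=web user'
# Keys whose value is Empty are dropped, so an unset shell/home/groups never reaches the task.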
def get_push_unixlike_system_user_tasks(system_user, username=None):
if username is None:
username = system_user.username
password = system_user.password
public_key = system_user.public_key
comment = system_user.name
groups = _split_by_comma(system_user.system_groups)
if groups:
groups = '"%s"' % ','.join(groups)
add_user_args = {
'name': username,
'shell': system_user.shell or Empty,
'state': 'present',
'home': system_user.home or Empty,
'groups': groups or Empty,
'comment': comment
}
tasks = [
{
'name': 'Add user {}'.format(username),
'action': {
'module': 'user',
'args': _dump_args(add_user_args),
'args': 'name={} shell={} state=present'.format(
username, system_user.shell or '/bin/bash',
),
}
},
{
@@ -66,27 +39,24 @@ def get_push_unixlike_system_user_tasks(system_user, username=None):
'module': 'group',
'args': 'name={} state=present'.format(username),
}
},
{
'name': 'Check home dir exists',
'action': {
'module': 'stat',
'args': 'path=/home/{}'.format(username)
},
'register': 'home_existed'
},
{
'name': "Set home dir permission",
'action': {
'module': 'file',
'args': "path=/home/{0} owner={0} group={0} mode=700".format(username)
},
'when': 'home_existed.stat.exists == true'
}
]
if not system_user.home:
tasks.extend([
{
'name': 'Check home dir exists',
'action': {
'module': 'stat',
'args': 'path=/home/{}'.format(username)
},
'register': 'home_existed'
},
{
'name': "Set home dir permission",
'action': {
'module': 'file',
'args': "path=/home/{0} owner={0} group={0} mode=700".format(username)
},
'when': 'home_existed.stat.exists == true'
}
])
if password:
tasks.append({
'name': 'Set {} password'.format(username),
@@ -132,14 +102,8 @@ def get_push_windows_system_user_tasks(system_user, username=None):
if username is None:
username = system_user.username
password = system_user.password
groups = {'Users', 'Remote Desktop Users'}
if system_user.system_groups:
groups.update(_split_by_comma(system_user.system_groups))
groups = ','.join(groups)
tasks = []
if not password:
logger.error("Error: no password found")
return tasks
task = {
'name': 'Add user {}'.format(username),
@@ -152,9 +116,9 @@ def get_push_windows_system_user_tasks(system_user, username=None):
'update_password=always '
'password_expired=no '
'password_never_expires=yes '
'groups="{}" '
'groups="Users,Remote Desktop Users" '
'groups_action=add '
''.format(username, username, password, groups),
''.format(username, username, password),
}
}
tasks.append(task)
@@ -215,22 +179,20 @@ def push_system_user_util(system_user, assets, task_name, username=None):
print(_("Start push system user for platform: [{}]").format(platform))
print(_("Hosts count: {}").format(len(_hosts)))
# If no asset-specific auth is configured, there is no need to push to each host individually
if not system_user.has_special_auth(username=username):
if not system_user.has_special_auth():
logger.debug("System user not has special auth")
tasks = get_push_system_user_tasks(system_user, platform, username=username)
run_task(tasks, _hosts)
continue
for _host in _hosts:
system_user.load_asset_special_auth(_host, username=username)
system_user.load_asset_special_auth(_host)
tasks = get_push_system_user_tasks(system_user, platform, username=username)
run_task(tasks, [_host])
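The per-platform loop above relies on the hosts having been grouped by platform first (group_asset_by_platform). Purely as a generic illustration of that grouping idiom with itertools.groupby, which this module imports, and hypothetical host tuples (not necessarily how group_asset_by_platform is implemented):

from itertools import groupby

hosts = [('web1', 'linux'), ('db1', 'linux'), ('win1', 'windows')]
hosts.sort(key=lambda h: h[1])  # groupby only groups consecutive items, so sort first
for platform, items in groupby(hosts, key=lambda h: h[1]):
    print(platform, [name for name, _ in items])
# linux ['web1', 'db1']
# windows ['win1']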
@shared_task(queue="ansible")
def push_system_user_to_assets_manual(system_user, username=None):
system_user = get_object_if_need(SystemUser, system_user)
assets = system_user.get_related_assets()
task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name=task_name, username=username)
@@ -247,12 +209,12 @@ def push_system_user_a_asset_manual(system_user, asset, username=None):
@shared_task(queue="ansible")
def push_system_user_to_assets(system_user_id, assets_id, username=None):
system_user = SystemUser.objects.get(id=system_user_id)
assets = get_objects(Asset, assets_id)
def push_system_user_to_assets(system_user, assets, username=None):
task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name, username=username)
# @shared_task
# @register_as_period_task(interval=3600)
# @after_app_ready_start

View File

@@ -31,10 +31,7 @@ def test_system_user_connectivity_util(system_user, assets, task_name):
"""
from ops.utils import update_or_create_ansible_task
# hosts = clean_ansible_task_hosts(assets, system_user=system_user)
# TODO: The system user is not passed here, because clean_ansible_task_hosts uses system_user to decide whether pushing is allowed,
# which does not fit the connectivity-test logic; this needs to be refactored later
hosts = clean_ansible_task_hosts(assets)
hosts = clean_ansible_task_hosts(assets, system_user=system_user)
if not hosts:
return {}
platform_hosts_map = {}

View File

@@ -1,8 +1,8 @@
# coding:utf-8
from django.urls import path, re_path
from rest_framework_nested import routers
# from rest_framework.routers import DefaultRouter
from rest_framework_bulk.routes import BulkRouter
from django.db.transaction import non_atomic_requests
from common import api as capi
@@ -45,7 +45,6 @@ urlpatterns = [
path('admin-users/<uuid:pk>/assets/', api.AdminUserAssetsListView.as_view(), name='admin-user-assets'),
path('system-users/<uuid:pk>/auth-info/', api.SystemUserAuthInfoApi.as_view(), name='system-user-auth-info'),
path('system-users/<uuid:pk>/assets/', api.SystemUserAssetsListView.as_view(), name='system-user-assets'),
path('system-users/<uuid:pk>/assets/<uuid:aid>/auth-info/', api.SystemUserAssetAuthInfoApi.as_view(), name='system-user-asset-auth-info'),
path('system-users/<uuid:pk>/tasks/', api.SystemUserTaskApi.as_view(), name='system-user-task-create'),
path('system-users/<uuid:pk>/cmd-filter-rules/', api.SystemUserCommandFilterRuleListApi.as_view(), name='system-user-cmd-filter-rule-list'),
@@ -56,9 +55,9 @@ urlpatterns = [
path('nodes/children/', api.NodeChildrenApi.as_view(), name='node-children-2'),
path('nodes/<uuid:pk>/children/add/', api.NodeAddChildrenApi.as_view(), name='node-add-children'),
path('nodes/<uuid:pk>/assets/', api.NodeAssetsApi.as_view(), name='node-assets'),
path('nodes/<uuid:pk>/assets/add/', non_atomic_requests(api.NodeAddAssetsApi.as_view()), name='node-add-assets'),
path('nodes/<uuid:pk>/assets/replace/', non_atomic_requests(api.MoveAssetsToNodeApi.as_view()), name='node-replace-assets'),
path('nodes/<uuid:pk>/assets/remove/', non_atomic_requests(api.NodeRemoveAssetsApi.as_view()), name='node-remove-assets'),
path('nodes/<uuid:pk>/assets/add/', api.NodeAddAssetsApi.as_view(), name='node-add-assets'),
path('nodes/<uuid:pk>/assets/replace/', api.NodeReplaceAssetsApi.as_view(), name='node-replace-assets'),
path('nodes/<uuid:pk>/assets/remove/', api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'),
path('nodes/<uuid:pk>/tasks/', api.NodeTaskCreateApi.as_view(), name='node-task-create'),
path('gateways/<uuid:pk>/test-connective/', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'),

View File

@@ -1,52 +1,195 @@
# ~*~ coding: utf-8 ~*~
#
from django.db.models import Q
from treelib import Tree
from treelib.exceptions import NodeIDAbsentError
from collections import defaultdict
from copy import deepcopy
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none
from common.http import is_true
from common.utils import get_logger, timeit, lazyproperty
from .models import Asset, Node
logger = get_logger(__file__)
def check_node_assets_amount():
for node in Node.objects.all():
assets_amount = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
class TreeService(Tree):
tag_sep = ' / '
if node.assets_amount != assets_amount:
print(f'>>> <Node:{node.key}> wrong assets amount '
f'{node.assets_amount} right is {assets_amount}')
node.assets_amount = assets_amount
node.save()
@staticmethod
@timeit
def get_nodes_assets_map():
nodes_assets_map = defaultdict(set)
asset_node_list = Node.assets.through.objects.values_list(
'asset', 'node__key'
)
for asset_id, key in asset_node_list:
nodes_assets_map[key].add(asset_id)
return nodes_assets_map
@classmethod
@timeit
def new(cls):
from .models import Node
all_nodes = list(Node.objects.all().values("key", "value"))
all_nodes.sort(key=lambda x: len(x["key"].split(":")))
tree = cls()
tree.create_node(tag='', identifier='', data={})
for node in all_nodes:
key = node["key"]
value = node["value"]
parent_key = ":".join(key.split(":")[:-1])
tree.safe_create_node(
tag=value, identifier=key,
parent=parent_key,
)
tree.init_assets()
return tree
def is_asset_exists_in_node(asset_pk, node_key):
return Asset.objects.filter(
id=asset_pk
).filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).exists()
def init_assets(self):
node_assets_map = self.get_nodes_assets_map()
for node in self.all_nodes_itr():
key = node.identifier
assets = node_assets_map.get(key, set())
data = {"assets": assets, "all_assets": None}
node.data = data
def safe_create_node(self, **kwargs):
parent = kwargs.get("parent")
if not self.contains(parent):
kwargs['parent'] = self.root
self.create_node(**kwargs)
def is_query_node_all_assets(request):
request = request
query_all_arg = request.query_params.get('all', 'true')
show_current_asset_arg = request.query_params.get('show_current_asset')
if show_current_asset_arg is not None:
return not is_true(show_current_asset_arg)
return is_true(query_all_arg)
def all_children_ids(self, nid, with_self=True):
children_ids = self.expand_tree(nid)
if not with_self:
next(children_ids)
return list(children_ids)
def all_children(self, nid, with_self=True, deep=False):
children_ids = self.all_children_ids(nid, with_self=with_self)
return [self.get_node(i, deep=deep) for i in children_ids]
def get_node(request):
node_id = dict_get_any(request.query_params, ['node', 'node_id'])
if not node_id:
return None
def ancestors_ids(self, nid, with_self=True):
ancestor_ids = list(self.rsearch(nid))
ancestor_ids.pop()
if not with_self:
ancestor_ids.pop(0)
return ancestor_ids
if is_uuid(node_id):
node = get_object_or_none(Node, id=node_id)
else:
node = get_object_or_none(Node, key=node_id)
return node
def ancestors(self, nid, with_self=False, deep=False):
ancestor_ids = self.ancestors_ids(nid, with_self=with_self)
ancestors = [self.get_node(i, deep=deep) for i in ancestor_ids]
return ancestors
def get_node_full_tag(self, nid):
ancestors = self.ancestors(nid, with_self=True)
ancestors.reverse()
return self.tag_sep.join([n.tag for n in ancestors])
def get_family(self, nid, deep=False):
ancestors = self.ancestors(nid, with_self=False, deep=deep)
children = self.all_children(nid, with_self=False)
return ancestors + [self[nid]] + children
@staticmethod
def is_parent(child, parent):
parent_id = child.bpointer
return parent_id == parent.identifier
def root_node(self):
return self.get_node(self.root)
def get_node(self, nid, deep=False):
node = super().get_node(nid)
if deep:
node = self.copy_node(node)
node.data = {}
return node
def parent(self, nid, deep=False):
parent = super().parent(nid)
if deep:
parent = self.copy_node(parent)
return parent
@lazyproperty
def invalid_assets(self):
assets = Asset.objects.filter(is_active=False).values_list('id', flat=True)
return assets
def set_assets(self, nid, assets):
node = self.get_node(nid)
if node.data is None:
node.data = {}
node.data["assets"] = assets
def assets(self, nid):
node = self.get_node(nid)
return node.data.get("assets", set())
def valid_assets(self, nid):
return set(self.assets(nid)) - set(self.invalid_assets)
def all_assets(self, nid):
node = self.get_node(nid)
if node.data is None:
node.data = {}
all_assets = node.data.get("all_assets")
if all_assets is not None:
return all_assets
all_assets = set(self.assets(nid))
try:
children = self.children(nid)
except NodeIDAbsentError:
children = []
for child in children:
all_assets.update(self.all_assets(child.identifier))
node.data["all_assets"] = all_assets
return all_assets
def all_valid_assets(self, nid):
return set(self.all_assets(nid)) - set(self.invalid_assets)
def assets_amount(self, nid):
return len(self.all_assets(nid))
def valid_assets_amount(self, nid):
return len(self.all_valid_assets(nid))
@staticmethod
def copy_node(node):
new_node = deepcopy(node)
new_node.fpointer = None
return new_node
def safe_add_ancestors(self, node, ancestors):
# If there are no ancestor nodes, add this node with the root node as its parent
if len(ancestors) == 0:
parent = self.root_node()
else:
parent = ancestors[0]
# If the current node is already in the tree, move it under the parent node;
# this happens when the node was previously attached under a second-level node
if not self.contains(parent.identifier):
# logger.debug('Add parent: {}'.format(parent.identifier))
self.safe_add_ancestors(parent, ancestors[1:])
if self.contains(node.identifier):
# msg = 'Move node to parent: {} => {}'.format(
# node.identifier, parent.identifier
# )
# logger.debug(msg)
self.move_node(node.identifier, parent.identifier)
else:
# logger.debug('Add node: {}'.format(node.identifier))
self.add_node(node, parent)
#
# def __getstate__(self):
# self.mutex = None
# self.all_nodes_assets_map = {}
# self.nodes_assets_map = {}
# return self.__dict__
# def __setstate__(self, state):
# self.__dict__ = state
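Since TreeService builds directly on treelib, a small standalone example of the treelib calls it relies on (create_node, expand_tree, rsearch) may help; the node keys below are hypothetical:

from treelib import Tree

tree = Tree()
tree.create_node(tag='', identifier='')  # empty root, mirroring TreeService.new()
tree.create_node(tag='Default', identifier='1', parent='')
tree.create_node(tag='Web', identifier='1:1', parent='1')

print(list(tree.expand_tree('1')))   # ['1', '1:1']     -> basis of all_children_ids()
print(list(tree.rsearch('1:1')))     # ['1:1', '1', ''] -> basis of ancestors_ids()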

View File

@@ -43,7 +43,7 @@ class UserLoginLogViewSet(ListModelMixin, CommonGenericViewSet):
@staticmethod
def get_org_members():
users = current_org.get_members().values_list('username', flat=True)
users = current_org.get_org_members().values_list('username', flat=True)
return users
def get_queryset(self):
@@ -79,7 +79,7 @@ class PasswordChangeLogViewSet(ListModelMixin, CommonGenericViewSet):
ordering = ['-datetime']
def get_queryset(self):
users = current_org.get_members()
users = current_org.get_org_members()
queryset = super().get_queryset().filter(
user__in=[user.__str__() for user in users]
)
@@ -107,7 +107,7 @@ class CommandExecutionViewSet(ListModelMixin, OrgGenericViewSet):
class CommandExecutionHostRelationViewSet(OrgRelationMixin, OrgBulkModelViewSet):
serializer_class = CommandExecutionHostsRelationSerializer
m2m_field = CommandExecution.hosts.field
permission_classes = [IsOrgAdmin | IsOrgAuditor]
permission_classes = (IsOrgAdmin,)
filter_fields = [
'id', 'asset', 'commandexecution'
]

View File

@@ -20,7 +20,7 @@ class CurrentOrgMembersFilter(filters.BaseFilterBackend):
]
def _get_user_list(self):
users = current_org.get_members(exclude=('Auditor',))
users = current_org.get_org_members(exclude=('Auditor',))
return users
def filter_queryset(self, request, queryset, view):

View File

@@ -1,18 +0,0 @@
# Generated by Django 2.2.10 on 2020-06-24 08:54
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('audits', '0008_auto_20200508_2105'),
]
operations = [
migrations.AlterField(
model_name='ftplog',
name='operate',
field=models.CharField(choices=[('Delete', 'Delete'), ('Upload', 'Upload'), ('Download', 'Download'), ('Rmdir', 'Rmdir'), ('Rename', 'Rename'), ('Mkdir', 'Mkdir'), ('Symlink', 'Symlink')], max_length=16, verbose_name='Operate'),
),
]

View File

@@ -1,18 +0,0 @@
# Generated by Django 2.2.13 on 2020-08-11 03:22
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('audits', '0009_auto_20200624_1654'),
]
operations = [
migrations.AlterField(
model_name='operatelog',
name='datetime',
field=models.DateTimeField(auto_now=True, db_index=True, verbose_name='Datetime'),
),
]

View File

@@ -14,30 +14,12 @@ __all__ = [
class FTPLog(OrgModelMixin):
OPERATE_DELETE = 'Delete'
OPERATE_UPLOAD = 'Upload'
OPERATE_DOWNLOAD = 'Download'
OPERATE_RMDIR = 'Rmdir'
OPERATE_RENAME = 'Rename'
OPERATE_MKDIR = 'Mkdir'
OPERATE_SYMLINK = 'Symlink'
OPERATE_CHOICES = (
(OPERATE_DELETE, _('Delete')),
(OPERATE_UPLOAD, _('Upload')),
(OPERATE_DOWNLOAD, _('Download')),
(OPERATE_RMDIR, _('Rmdir')),
(OPERATE_RENAME, _('Rename')),
(OPERATE_MKDIR, _('Mkdir')),
(OPERATE_SYMLINK, _('Symlink'))
)
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
user = models.CharField(max_length=128, verbose_name=_('User'))
remote_addr = models.CharField(max_length=128, verbose_name=_("Remote addr"), blank=True, null=True)
asset = models.CharField(max_length=1024, verbose_name=_("Asset"))
system_user = models.CharField(max_length=128, verbose_name=_("System user"))
operate = models.CharField(max_length=16, verbose_name=_("Operate"), choices=OPERATE_CHOICES)
operate = models.CharField(max_length=16, verbose_name=_("Operate"))
filename = models.CharField(max_length=1024, verbose_name=_("Filename"))
is_success = models.BooleanField(default=True, verbose_name=_("Success"))
date_start = models.DateTimeField(auto_now_add=True, verbose_name=_('Date start'))
@@ -58,7 +40,7 @@ class OperateLog(OrgModelMixin):
resource_type = models.CharField(max_length=64, verbose_name=_("Resource Type"))
resource = models.CharField(max_length=128, verbose_name=_("Resource"))
remote_addr = models.CharField(max_length=128, verbose_name=_("Remote addr"), blank=True, null=True)
datetime = models.DateTimeField(auto_now=True, verbose_name=_('Datetime'), db_index=True)
datetime = models.DateTimeField(auto_now=True, verbose_name=_('Datetime'))
def __str__(self):
return "<{}> {} <{}>".format(self.user, self.action, self.resource)
@@ -124,7 +106,7 @@ class UserLoginLog(models.Model):
Q(username__contains=keyword)
)
if not current_org.is_root():
username_list = current_org.get_members().values_list('username', flat=True)
username_list = current_org.get_org_members().values_list('username', flat=True)
login_logs = login_logs.filter(username__in=username_list)
return login_logs

View File

@@ -12,20 +12,19 @@ from . import models
class FTPLogSerializer(serializers.ModelSerializer):
operate_display = serializers.ReadOnlyField(source='get_operate_display', label=_('Operate for display'))
class Meta:
model = models.FTPLog
fields = (
'id', 'user', 'remote_addr', 'asset', 'system_user', 'org_id',
'operate', 'filename', 'is_success', 'date_start', 'operate_display'
'id', 'user', 'remote_addr', 'asset', 'system_user',
'operate', 'filename', 'is_success', 'date_start'
)
class UserLoginLogSerializer(serializers.ModelSerializer):
type_display = serializers.ReadOnlyField(source='get_type_display', label=_('Type for display'))
status_display = serializers.ReadOnlyField(source='get_status_display', label=_('Status for display'))
mfa_display = serializers.ReadOnlyField(source='get_mfa_display', label=_('MFA for display'))
type_display = serializers.ReadOnlyField(source='get_type_display')
status_display = serializers.ReadOnlyField(source='get_status_display')
mfa_display = serializers.ReadOnlyField(source='get_mfa_display')
class Meta:
model = models.UserLoginLog
@@ -33,9 +32,6 @@ class UserLoginLogSerializer(serializers.ModelSerializer):
'id', 'username', 'type', 'type_display', 'ip', 'city', 'user_agent',
'mfa', 'reason', 'status', 'status_display', 'datetime', 'mfa_display'
)
extra_kwargs = {
"user_agent": {'label': _('User agent')}
}
class OperateLogSerializer(serializers.ModelSerializer):
@@ -43,7 +39,7 @@ class OperateLogSerializer(serializers.ModelSerializer):
model = models.OperateLog
fields = (
'id', 'user', 'action', 'resource_type', 'resource',
'remote_addr', 'datetime', 'org_id'
'remote_addr', 'datetime'
)
@@ -69,7 +65,7 @@ class CommandExecutionSerializer(serializers.ModelSerializer):
fields_mini = ['id']
fields_small = fields_mini + [
'run_as', 'command', 'user', 'is_finished',
'date_start', 'result', 'is_success', 'org_id'
'date_start', 'result', 'is_success'
]
fields = fields_small + ['hosts', 'run_as_display', 'user_display']
extra_kwargs = {
@@ -78,8 +74,6 @@ class CommandExecutionSerializer(serializers.ModelSerializer):
'hosts': {'label': _('Hosts')},  # Foreign key that triggers extra SQL; do not change this on the model
'run_as': {'label': _('Run as')},
'user': {'label': _('User')},
'run_as_display': {'label': _('Run as for display')},
'user_display': {'label': _('User for display')},
}
@classmethod

View File

@@ -2,44 +2,32 @@
#
import datetime
from django.utils import timezone
from django.conf import settings
from celery import shared_task
from ops.celery.decorator import (
register_as_period_task, after_app_shutdown_clean_periodic
)
from ops.celery.decorator import register_as_period_task
from .models import UserLoginLog, OperateLog
from common.utils import get_log_keep_day
@shared_task
@after_app_shutdown_clean_periodic
def clean_login_log_period():
now = timezone.now()
days = get_log_keep_day('LOGIN_LOG_KEEP_DAYS')
expired_day = now - datetime.timedelta(days=days)
UserLoginLog.objects.filter(datetime__lt=expired_day).delete()
@shared_task
@after_app_shutdown_clean_periodic
def clean_operation_log_period():
now = timezone.now()
days = get_log_keep_day('OPERATE_LOG_KEEP_DAYS')
expired_day = now - datetime.timedelta(days=days)
OperateLog.objects.filter(datetime__lt=expired_day).delete()
@shared_task
def clean_ftp_log_period():
now = timezone.now()
days = get_log_keep_day('FTP_LOG_KEEP_DAYS')
expired_day = now - datetime.timedelta(days=days)
OperateLog.objects.filter(datetime__lt=expired_day).delete()
@register_as_period_task(interval=3600*24)
@shared_task
def clean_audits_log_period():
clean_login_log_period()
clean_operation_log_period()
clean_ftp_log_period()
def clean_login_log_period():
now = timezone.now()
try:
days = int(settings.LOGIN_LOG_KEEP_DAYS)
except ValueError:
days = 90
expired_day = now - datetime.timedelta(days=days)
UserLoginLog.objects.filter(datetime__lt=expired_day).delete()
@register_as_period_task(interval=3600*24)
@shared_task
def clean_operation_log_period():
now = timezone.now()
try:
days = int(settings.LOGIN_LOG_KEEP_DAYS)
except ValueError:
days = 90
expired_day = now - datetime.timedelta(days=days)
OperateLog.objects.filter(datetime__lt=expired_day).delete()
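Both versions of these tasks compute the same cutoff: now minus a configured number of days, then delete everything older. A standalone sketch of that cutoff calculation (plain datetime here; the tasks above use django.utils.timezone.now()):

import datetime

def expired_before(days: int) -> datetime.datetime:
    # Anything with a timestamp earlier than this moment is eligible for deletion
    return datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=days)

print(expired_before(90))   # 90 is the fallback used in the older branch above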

View File

@@ -6,4 +6,3 @@ from .token import *
from .mfa import *
from .access_key import *
from .login_confirm import *
from .sso import *

View File

@@ -34,6 +34,16 @@ class LoginConfirmSettingUpdateApi(UpdateAPIView):
class TicketStatusApi(mixins.AuthMixin, APIView):
permission_classes = ()
def get_ticket(self):
from tickets.models import Ticket
ticket_id = self.request.session.get("auth_ticket_id")
logger.debug('Login confirm ticket id: {}'.format(ticket_id))
if not ticket_id:
ticket = None
else:
ticket = get_object_or_none(Ticket, pk=ticket_id)
return ticket
def get(self, request, *args, **kwargs):
try:
self.check_user_login_confirm()

View File

@@ -1,87 +0,0 @@
from uuid import UUID
from urllib.parse import urlencode
from django.contrib.auth import login
from django.conf import settings
from django.http.response import HttpResponseRedirect
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.request import Request
from common.utils.timezone import utcnow
from common.const.http import POST, GET
from common.drf.api import JmsGenericViewSet
from common.drf.serializers import EmptySerializer
from common.permissions import IsSuperUser
from common.utils import reverse
from users.models import User
from ..serializers import SSOTokenSerializer
from ..models import SSOToken
from ..filters import AuthKeyQueryDeclaration
from ..mixins import AuthMixin
from ..errors import SSOAuthClosed
NEXT_URL = 'next'
AUTH_KEY = 'authkey'
class SSOViewSet(AuthMixin, JmsGenericViewSet):
queryset = SSOToken.objects.all()
serializer_classes = {
'login_url': SSOTokenSerializer,
'login': EmptySerializer
}
@action(methods=[POST], detail=False, permission_classes=[IsSuperUser], url_path='login-url')
def login_url(self, request, *args, **kwargs):
if not settings.AUTH_SSO:
raise SSOAuthClosed()
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
username = serializer.validated_data['username']
user = User.objects.get(username=username)
next_url = serializer.validated_data.get(NEXT_URL)
operator = request.user.username
# TODO: `created_by` and `updated_by` could be handled centrally via `ThreadLocal`
token = SSOToken.objects.create(user=user, created_by=operator, updated_by=operator)
query = {
AUTH_KEY: token.authkey,
NEXT_URL: next_url or ''
}
login_url = '%s?%s' % (reverse('api-auth:sso-login', external=True), urlencode(query))
return Response(data={'login_url': login_url})
@action(methods=[GET], detail=False, filter_backends=[AuthKeyQueryDeclaration], permission_classes=[])
def login(self, request: Request, *args, **kwargs):
"""
This endpoint violates the RESTful convention:
`GET` is supposed to be a safe method, but this endpoint is not
"""
request.META['HTTP_X_JMS_LOGIN_TYPE'] = 'W'
authkey = request.query_params.get(AUTH_KEY)
next_url = request.query_params.get(NEXT_URL)
if not next_url or not next_url.startswith('/'):
next_url = reverse('index')
try:
authkey = UUID(authkey)
token = SSOToken.objects.get(authkey=authkey, expired=False)
# Expire it first, so the key can only be used this one time
token.expired = True
token.save()
except (ValueError, SSOToken.DoesNotExist):
self.send_auth_signal(success=False, reason='authkey_invalid')
return HttpResponseRedirect(next_url)
# Check whether the key has timed out
if (utcnow().timestamp() - token.date_created.timestamp()) > settings.AUTH_SSO_AUTHKEY_TTL:
self.send_auth_signal(success=False, reason='authkey_timeout')
return HttpResponseRedirect(next_url)
user = token.user
login(self.request, user, 'authentication.backends.api.SSOAuthentication')
self.send_auth_signal(success=True, user=user)
return HttpResponseRedirect(next_url)
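The view above implements a single-use, time-limited auth key: the token is marked expired before the TTL check so it can never be replayed. A standalone sketch of that pattern (in-memory store; the TTL value and names are hypothetical, standing in for AUTH_SSO_AUTHKEY_TTL and the SSOToken model):

import time
from uuid import uuid4

TTL_SECONDS = 60 * 5          # hypothetical TTL
_tokens = {}                  # authkey -> [username, created_ts, expired]

def issue(username: str) -> str:
    key = str(uuid4())
    _tokens[key] = [username, time.time(), False]
    return key

def consume(key: str):
    record = _tokens.get(key)
    if record is None or record[2]:
        return None           # unknown or already used
    record[2] = True          # expire first: the key is single-use
    if time.time() - record[1] > TTL_SECONDS:
        return None           # issued too long ago
    return record[0]

key = issue('admin')
print(consume(key))   # 'admin'
print(consume(key))   # None - second use is rejected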

View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
from django.shortcuts import redirect
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.generics import CreateAPIView
@@ -40,5 +40,3 @@ class TokenCreateApi(AuthMixin, CreateAPIView):
return Response(e.as_data(), status=400)
except errors.NeedMoreInfoError as e:
return Response(e.as_data(), status=200)
except errors.PasswdTooSimple as e:
return redirect(e.url)

View File

@@ -5,13 +5,14 @@ import uuid
import time
from django.core.cache import cache
from django.conf import settings
from django.utils.translation import ugettext as _
from django.utils.six import text_type
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import authentication, exceptions
from common.auth import signature
from rest_framework.authentication import CSRFCheck
from common.utils import get_object_or_none, make_signature, http_to_unixtime
from ..models import AccessKey, PrivateToken
@@ -196,12 +197,3 @@ class SignatureAuthentication(signature.SignatureAuthentication):
return user, secret
except AccessKey.DoesNotExist:
return None, None
class SSOAuthentication(ModelBackend):
"""
Does nothing at all 😺
"""
def authenticate(self, request, sso_token=None, **kwargs):
pass

View File

@@ -1,10 +0,0 @@
from django_cas_ng.middleware import CASMiddleware as _CASMiddleware
from django.core.exceptions import MiddlewareNotUsed
from django.conf import settings
class CASMiddleware(_CASMiddleware):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not settings.AUTH_CAS:
raise MiddlewareNotUsed

View File

@@ -1,10 +0,0 @@
from jms_oidc_rp.middleware import OIDCRefreshIDTokenMiddleware as _OIDCRefreshIDTokenMiddleware
from django.core.exceptions import MiddlewareNotUsed
from django.conf import settings
class OIDCRefreshIDTokenMiddleware(_OIDCRefreshIDTokenMiddleware):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not settings.AUTH_OPENID:
raise MiddlewareNotUsed

View File

@@ -1,17 +1,17 @@
# -*- coding: utf-8 -*-
#
import traceback
from django.contrib.auth import get_user_model, authenticate
from django.contrib.auth import get_user_model
from radiusauth.backends import RADIUSBackend, RADIUSRealmBackend
from django.conf import settings
from pyrad.packet import AccessRequest
User = get_user_model()
class CreateUserMixin:
def get_django_user(self, username, password=None, *args, **kwargs):
def get_django_user(self, username, password=None):
if isinstance(username, bytes):
username = username.decode()
try:
@@ -27,23 +27,10 @@ class CreateUserMixin:
user.save()
return user
def _perform_radius_auth(self, client, packet):
# TODO: Waiting for the upstream library to fix this bug
try:
return super()._perform_radius_auth(client, packet)
except UnicodeError as e:
import sys
tb = ''.join(traceback.format_exception(*sys.exc_info(), limit=2, chain=False))
if tb.find("cl.decode") != -1:
return [], False, False
return None
class RadiusBackend(CreateUserMixin, RADIUSBackend):
def authenticate(self, request, username='', password='', **kwargs):
return super().authenticate(request, username=username, password=password)
pass
class RadiusRealmBackend(CreateUserMixin, RADIUSRealmBackend):
def authenticate(self, request, username='', password='', realm=None, **kwargs):
return super().authenticate(request, username=username, password=password, realm=realm)
pass

View File

@@ -1,2 +0,0 @@
RSA_PRIVATE_KEY = 'rsa_private_key'
RSA_PUBLIC_KEY = 'rsa_public_key'

View File

@@ -4,14 +4,12 @@ from django.utils.translation import ugettext_lazy as _
from django.urls import reverse
from django.conf import settings
from common.exceptions import JMSException
from .signals import post_auth_failed
from users.utils import (
increase_login_failed_count, get_login_failed_count
)
reason_password_failed = 'password_failed'
reason_password_decrypt_failed = 'password_decrypt_failed'
reason_mfa_failed = 'mfa_failed'
reason_mfa_unset = 'mfa_unset'
reason_user_not_exist = 'user_not_exist'
@@ -21,7 +19,6 @@ reason_user_inactive = 'user_inactive'
reason_choices = {
reason_password_failed: _('Username/password check failed'),
reason_password_decrypt_failed: _('Password decrypt failed'),
reason_mfa_failed: _('MFA failed'),
reason_mfa_unset: _('MFA unset'),
reason_user_not_exist: _("Username does not exist"),
@@ -206,17 +203,3 @@ class LoginConfirmOtherError(LoginConfirmBaseError):
def __init__(self, ticket_id, status):
msg = login_confirm_error_msg.format(status)
super().__init__(ticket_id=ticket_id, msg=msg)
class SSOAuthClosed(JMSException):
default_code = 'sso_auth_closed'
default_detail = _('SSO auth closed')
class PasswdTooSimple(JMSException):
default_code = 'passwd_too_simple'
default_detail = _('Your password is too simple, please change it for security')
def __init__(self, url, *args, **kwargs):
super(PasswdTooSimple, self).__init__(*args, **kwargs)
self.url = url

View File

@@ -1,15 +0,0 @@
from rest_framework import filters
from rest_framework.compat import coreapi, coreschema
class AuthKeyQueryDeclaration(filters.BaseFilterBackend):
def get_schema_fields(self, view):
return [
coreapi.Field(
name='authkey', location='query', required=True, type='string',
schema=coreschema.String(
title='authkey',
description='authkey'
)
)
]

View File

@@ -2,7 +2,6 @@
#
from django import forms
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from captcha.fields import CaptchaField
@@ -11,7 +10,7 @@ class UserLoginForm(forms.Form):
username = forms.CharField(label=_('Username'), max_length=100)
password = forms.CharField(
label=_('Password'), widget=forms.PasswordInput,
max_length=1024, strip=False
max_length=128, strip=False
)
def confirm_login_allowed(self, user):
@@ -22,24 +21,9 @@ class UserLoginForm(forms.Form):
)
class UserCheckOtpCodeForm(forms.Form):
otp_code = forms.CharField(label=_('MFA code'), max_length=6)
class CaptchaMixin(forms.Form):
class UserLoginCaptchaForm(UserLoginForm):
captcha = CaptchaField()
class ChallengeMixin(forms.Form):
challenge = forms.CharField(label=_('MFA code'), max_length=6,
required=False)
def get_user_login_form_cls(*, captcha=False):
bases = []
if settings.SECURITY_LOGIN_CAPTCHA_ENABLED and captcha:
bases.append(CaptchaMixin)
if settings.SECURITY_LOGIN_CHALLENGE_ENABLED:
bases.append(ChallengeMixin)
bases.append(UserLoginForm)
return type('UserLoginForm', tuple(bases), {})
class UserCheckOtpCodeForm(forms.Form):
otp_code = forms.CharField(label=_('MFA code'), max_length=6)
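get_user_login_form_cls above composes the login form class from mixins at runtime with type(). A minimal standalone illustration of that idiom, using plain classes instead of Django forms (names are hypothetical):

class LoginFormBase:
    pass

class CaptchaMixin:
    pass

def build_login_form_cls(captcha=False):
    # Mixins go first so their attributes take precedence in the MRO
    bases = ([CaptchaMixin] if captcha else []) + [LoginFormBase]
    return type('LoginForm', tuple(bases), {})

print(build_login_form_cls(captcha=True).__mro__)
# (LoginForm, CaptchaMixin, LoginFormBase, object)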

View File

@@ -1,32 +0,0 @@
# Generated by Django 2.2.10 on 2020-07-31 08:36
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('authentication', '0003_loginconfirmsetting'),
]
operations = [
migrations.CreateModel(
name='SSOToken',
fields=[
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('authkey', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False, verbose_name='Token')),
('expired', models.BooleanField(default=False, verbose_name='Expired')),
('user', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.PROTECT, to=settings.AUTH_USER_MODEL, verbose_name='User')),
],
options={
'abstract': False,
},
),
]

View File

@@ -1,13 +1,7 @@
# -*- coding: utf-8 -*-
#
from urllib.parse import urlencode
from functools import partial
import time
from django.conf import settings
from django.contrib.auth import authenticate
from django.shortcuts import reverse
from django.contrib.auth import BACKEND_SESSION_KEY
from common.utils import get_object_or_none, get_request_ip, get_logger
from users.models import User
@@ -15,9 +9,8 @@ from users.utils import (
is_block_login, clean_failed_count
)
from . import errors
from .utils import rsa_decrypt
from .utils import check_user_valid
from .signals import post_auth_success, post_auth_failed
from .const import RSA_PRIVATE_KEY
logger = get_logger(__name__)
@@ -28,14 +21,8 @@ class AuthMixin:
def get_user_from_session(self):
if self.request.session.is_empty():
raise errors.SessionEmptyError()
if all((self.request.user,
not self.request.user.is_anonymous,
BACKEND_SESSION_KEY in self.request.session)):
user = self.request.user
user.backend = self.request.session[BACKEND_SESSION_KEY]
return user
if self.request.user and not self.request.user.is_anonymous:
return self.request.user
user_id = self.request.session.get('user_id')
if not user_id:
user = None
@@ -53,7 +40,7 @@ class AuthMixin:
ip = ip or get_request_ip(self.request)
return ip
def check_is_block(self, raise_exception=True):
def check_is_block(self):
if hasattr(self.request, 'data'):
username = self.request.data.get("username")
else:
@@ -61,60 +48,27 @@ class AuthMixin:
ip = self.get_request_ip()
if is_block_login(username, ip):
logger.warn('Ip was blocked' + ': ' + username + ':' + ip)
exception = errors.BlockLoginError(username=username, ip=ip)
if raise_exception:
raise errors.BlockLoginError(username=username, ip=ip)
else:
return exception
raise errors.BlockLoginError(username=username, ip=ip)
def decrypt_passwd(self, raw_passwd):
# Fetch the decryption key and decrypt the password
rsa_private_key = self.request.session.get(RSA_PRIVATE_KEY)
if rsa_private_key is not None:
try:
return rsa_decrypt(raw_passwd, rsa_private_key)
except Exception as e:
logger.error(e, exc_info=True)
logger.error(f'Decrypt password faild: password[{raw_passwd}] rsa_private_key[{rsa_private_key}]')
return None
return raw_passwd
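decrypt_passwd defers to an rsa_decrypt helper whose implementation is not shown in this diff. Purely as an assumption, one common way such a helper could look with pycryptodome (the library choice, PKCS#1 v1.5 padding and base64 transport are all assumptions, not necessarily what the real rsa_decrypt does):

import base64
from Crypto.PublicKey import RSA       # pycryptodome
from Crypto.Cipher import PKCS1_v1_5

def rsa_decrypt_sketch(ciphertext_b64: str, private_key_pem: str) -> str:
    # Decode the transported ciphertext, then decrypt with the session's private key
    key = RSA.import_key(private_key_pem)
    cipher = PKCS1_v1_5.new(key)
    plaintext = cipher.decrypt(base64.b64decode(ciphertext_b64), None)
    if plaintext is None:              # sentinel returned on padding failure
        raise ValueError('decrypt failed')
    return plaintext.decode('utf-8')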
def check_user_auth(self, decrypt_passwd=False):
def check_user_auth(self):
self.check_is_block()
request = self.request
if hasattr(request, 'data'):
data = request.data
username = request.data.get('username', '')
password = request.data.get('password', '')
public_key = request.data.get('public_key', '')
else:
data = request.POST
username = data.get('username', '')
password = data.get('password', '')
challenge = data.get('challenge', '')
public_key = data.get('public_key', '')
username = request.POST.get('username', '')
password = request.POST.get('password', '')
public_key = request.POST.get('public_key', '')
user, error = check_user_valid(
request=request, username=username, password=password, public_key=public_key
)
ip = self.get_request_ip()
CredentialError = partial(errors.CredentialError, username=username, ip=ip, request=request)
if decrypt_passwd:
password = self.decrypt_passwd(password)
if not password:
raise CredentialError(error=errors.reason_password_decrypt_failed)
user = authenticate(request,
username=username,
password=password + challenge.strip(),
public_key=public_key)
if not user:
raise CredentialError(error=errors.reason_password_failed)
elif user.is_expired:
raise CredentialError(error=errors.reason_user_inactive)
elif not user.is_active:
raise CredentialError(error=errors.reason_user_inactive)
elif user.password_has_expired:
raise CredentialError(error=errors.reason_password_expired)
self._check_passwd_is_too_simple(user, password)
raise errors.CredentialError(
username=username, error=error, ip=ip, request=request
)
clean_failed_count(username, ip)
request.session['auth_password'] = 1
request.session['user_id'] = str(user.id)
@@ -122,30 +76,14 @@ class AuthMixin:
request.session['auth_backend'] = auth_backend
return user
@classmethod
def _check_passwd_is_too_simple(cls, user, password):
if user.is_superuser and password == 'admin':
reset_passwd_url = reverse('authentication:reset-password')
query_str = urlencode({
'token': user.generate_reset_token()
})
reset_passwd_url = f'{reset_passwd_url}?{query_str}'
flash_page_url = reverse('authentication:passwd-too-simple-flash-msg')
query_str = urlencode({
'redirect_url': reset_passwd_url
})
raise errors.PasswdTooSimple(f'{flash_page_url}?{query_str}')
def check_user_auth_if_need(self, decrypt_passwd=False):
def check_user_auth_if_need(self):
request = self.request
if request.session.get('auth_password') and \
request.session.get('user_id'):
user = self.get_user_from_session()
if user:
return user
return self.check_user_auth(decrypt_passwd=decrypt_passwd)
return self.check_user_auth()
def check_user_mfa_if_need(self, user):
if self.request.session.get('auth_mfa'):
@@ -174,12 +112,12 @@ class AuthMixin:
if not ticket_id:
ticket = None
else:
ticket = Ticket.origin_objects.get(pk=ticket_id)
ticket = get_object_or_none(Ticket, pk=ticket_id)
return ticket
def get_ticket_or_create(self, confirm_setting):
ticket = self.get_ticket()
if not ticket or ticket.status == ticket.STATUS.CLOSED:
if not ticket or ticket.status == ticket.STATUS_CLOSED:
ticket = confirm_setting.create_confirm_ticket(self.request)
self.request.session['auth_ticket_id'] = str(ticket.id)
return ticket
@@ -188,12 +126,12 @@ class AuthMixin:
ticket = self.get_ticket()
if not ticket:
raise errors.LoginConfirmOtherError('', "Not found")
if ticket.status == ticket.STATUS.OPEN:
if ticket.status == ticket.STATUS_OPEN:
raise errors.LoginConfirmWaitError(ticket.id)
elif ticket.action == ticket.ACTION.APPROVE:
elif ticket.action == ticket.ACTION_APPROVE:
self.request.session["auth_confirm"] = "1"
return
elif ticket.action == ticket.ACTION.REJECT:
elif ticket.action == ticket.ACTION_REJECT:
raise errors.LoginConfirmOtherError(
ticket.id, ticket.get_action_display()
)

View File

@@ -1,13 +1,10 @@
import uuid
from functools import partial
from django.db import models
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _, ugettext as __
from rest_framework.authtoken.models import Token
from django.conf import settings
from django.utils.crypto import get_random_string
from common.db import models
from common.mixins.models import CommonModelMixin
from common.utils import get_object_or_none, get_request_ip, get_ip_city
@@ -53,7 +50,7 @@ class LoginConfirmSetting(CommonModelMixin):
def create_confirm_ticket(self, request=None):
from tickets.models import Ticket
title = _('Login confirm') + ' {}'.format(self.user)
title = _('Login confirm') + '{}'.format(self.user)
if request:
remote_addr = get_request_ip(request)
city = get_ip_city(remote_addr)
@@ -71,7 +68,7 @@ class LoginConfirmSetting(CommonModelMixin):
reviewer = self.reviewers.all()
ticket = Ticket.objects.create(
user=self.user, title=title, body=body,
type=Ticket.TYPE.LOGIN_CONFIRM,
type=Ticket.TYPE_LOGIN_CONFIRM,
)
ticket.assignees.set(reviewer)
return ticket
@@ -79,12 +76,3 @@ class LoginConfirmSetting(CommonModelMixin):
def __str__(self):
return '{} confirm'.format(self.user.username)
class SSOToken(models.JMSBaseModel):
"""
Similar to Tencent Exmail's [SSO](https://exmail.qq.com/qy_mng_logic/doc#10036).
For security, a `token` expires immediately after a single use, but every generated `token` is kept.
"""
authkey = models.UUIDField(primary_key=True, default=uuid.uuid4, verbose_name=_('Token'))
expired = models.BooleanField(default=False, verbose_name=_('Expired'))
user = models.ForeignKey('users.User', on_delete=models.PROTECT, verbose_name=_('User'), db_constraint=False)

View File

@@ -5,12 +5,12 @@ from rest_framework import serializers
from common.utils import get_object_or_none
from users.models import User
from users.serializers import UserProfileSerializer
from .models import AccessKey, LoginConfirmSetting, SSOToken
from .models import AccessKey, LoginConfirmSetting
__all__ = [
'AccessKeySerializer', 'OtpVerifySerializer', 'BearerTokenSerializer',
'MFAChallengeSerializer', 'LoginConfirmSettingSerializer', 'SSOTokenSerializer',
'MFAChallengeSerializer', 'LoginConfirmSettingSerializer',
]
@@ -76,9 +76,3 @@ class LoginConfirmSettingSerializer(serializers.ModelSerializer):
model = LoginConfirmSetting
fields = ['id', 'user', 'reviewers', 'date_created', 'date_updated']
read_only_fields = ['date_created', 'date_updated']
class SSOTokenSerializer(serializers.Serializer):
username = serializers.CharField(write_only=True)
login_url = serializers.CharField(read_only=True)
next = serializers.CharField(write_only=True, allow_blank=True, required=False, allow_null=True)

View File

@@ -1,27 +1,10 @@
from importlib import import_module
from django.conf import settings
from django.contrib.auth import user_logged_in
from django.core.cache import cache
from django.dispatch import receiver
from django_cas_ng.signals import cas_user_authenticated
from jms_oidc_rp.signals import openid_user_login_failed, openid_user_login_success
from .signals import post_auth_success, post_auth_failed
@receiver(user_logged_in)
def on_user_auth_login_success(sender, user, request, **kwargs):
if settings.USER_LOGIN_SINGLE_MACHINE_ENABLED:
user_id = 'single_machine_login_' + str(user.id)
session_key = cache.get(user_id)
if session_key and session_key != request.session.session_key:
session = import_module(settings.SESSION_ENGINE).SessionStore(session_key)
session.delete()
cache.set(user_id, request.session.session_key, None)
@receiver(openid_user_login_success)
def on_oidc_user_login_success(sender, request, user, **kwargs):
post_auth_success.send(sender, user=user, request=request)
@@ -30,8 +13,3 @@ def on_oidc_user_login_success(sender, request, user, **kwargs):
@receiver(openid_user_login_failed)
def on_oidc_user_login_failed(sender, username, request, reason, **kwargs):
post_auth_failed.send(sender, username=username, request=request, reason=reason)
@receiver(cas_user_authenticated)
def on_cas_user_login_success(sender, request, user, **kwargs):
post_auth_success.send(sender, user=user, request=request)

Some files were not shown because too many files have changed in this diff.