Compare commits

..

15 Commits
v2.4 ... v2.0.2

Author SHA1 Message Date
ibuler
7c08582e7d Merge branch 'master' into v2.0 2020-07-08 16:14:19 +08:00
ibuler
3555fe1c20 Merge branch 'v2.0' of github.com:jumpserver/jumpserver into v2.0 2020-07-08 16:14:06 +08:00
ibuler
4a1045ba81 ci(build&release): 自动化构建发布 2020-07-08 15:58:57 +08:00
ibuler
7fc3218010 添加example api 2020-07-08 15:58:57 +08:00
ibuler
0dd4c8adc2 Merge branch 'v2.0' of github.com:jumpserver/jumpserver into v2.0 2020-07-08 15:48:05 +08:00
xinwen
a9626c2b39 Merge pull request #4223 from jumpserver/update-add-filter
[Update] assets/gathered_user 添加过滤字段
2020-07-01 17:32:58 +08:00
xinwen
b265fad50f [Update] assets/gathered_user 添加过滤字段 2020-07-01 17:21:00 +08:00
BaiJiangJie
f40b50dddd Merge pull request #4209 from jumpserver/ftp_log_order
fix:修改 ftp 日志按开始日期排序
2020-07-01 15:02:44 +08:00
xinwen
0d6255b07f Merge pull request #4218 from jumpserver/system-user-add-filter
System user add filter
2020-07-01 15:01:07 +08:00
xinwen
06b02cfcfd [Update] 系统用户添加过滤字段 2020-07-01 14:57:52 +08:00
jym503558564
a7e1ed6f03 fix:修改 ftp 日志按开始日期排序 2020-07-01 11:04:42 +08:00
xinwen
1c6f89519a Merge pull request #4201 from jumpserver/fix-i18n
Fix i18n
2020-06-30 14:21:44 +08:00
xinwen
24d093747f [Fix] X-Pack/云管中心 i18n 2020-06-30 14:18:21 +08:00
ibuler
7548bb8976 Merge branch 'dev' of github.com:jumpserver/jumpserver into dev 2020-06-17 11:33:57 +08:00
ibuler
8ffa5e0aec 添加example api 2020-06-16 20:11:50 +08:00
264 changed files with 2912 additions and 8314 deletions

17
.github/ISSUE_TEMPLATE.md vendored Normal file
View File

@@ -0,0 +1,17 @@
[简述你的问题]
##### 使用版本
[请提供你使用的JumpServer版本 如 2.0.1 注: 1.4及以下版本不再提供支持]
##### 问题复现步骤
1. [步骤1]
2. [步骤2]
##### 具体表现[截图可能会更好些,最好能截全]
##### 其他
[注:] 完成后请关闭 issue

View File

@@ -1,9 +0,0 @@
---
name: 需求建议
about: 提出针对本项目的想法和建议
title: "[Feature] "
labels: 待处理, 需求
assignees: 'ibuler'
---
**请描述您的需求或者改进建议.**

View File

@@ -1,22 +0,0 @@
---
name: Bug 提交
about: 提交产品缺陷帮助我们更好的改进
title: "[Bug] "
labels: bug, 待处理
assignees: wojiushixiaobai
---
**JumpServer 版本(v1.5.9以下不再支持)**
**浏览器版本**
**Bug 描述**
**Bug 重现步骤(有截图更好)**
1.
2.
3.

View File

@@ -1,10 +0,0 @@
---
name: 问题咨询
about: 提出针对本项目安装部署、使用及其他方面的相关问题
title: "[Question] "
labels: 提问, 待处理
assignees: wojiushixiaobai
---
**请描述您的问题.**

View File

@@ -1,12 +0,0 @@
on: [push, pull_request, release]
name: JumpServer repos generic handler
jobs:
generic_handler:
name: Run generic handler
runs-on: ubuntu-latest
steps:
- uses: jumpserver/action-generic-handler@master
env:
GITHUB_TOKEN: ${{ secrets.PRIVATE_TOKEN }}

View File

@@ -1,39 +1,19 @@
FROM registry.fit2cloud.com/public/python:v3 as stage-build
MAINTAINER Jumpserver Team <ibuler@qq.com>
ARG VERSION
ENV VERSION=$VERSION
WORKDIR /opt/jumpserver
ADD . .
RUN cd utils && bash -ixeu build.sh
FROM registry.fit2cloud.com/public/python:v3
ARG PIP_MIRROR=https://pypi.douban.com/simple
ENV PIP_MIRROR=$PIP_MIRROR
ARG MYSQL_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/mysql/yum/mysql57-community-el6/
ENV MYSQL_MIRROR=$MYSQL_MIRROR
MAINTAINER Jumpserver Team <ibuler@qq.com>
WORKDIR /opt/jumpserver
COPY ./requirements ./requirements
RUN useradd jumpserver
RUN wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-6.repo \
&& sed -i 's@/centos/@/centos-vault/@g' /etc/yum.repos.d/CentOS-Base.repo \
&& sed -i 's@$releasever@6.10@g' /etc/yum.repos.d/CentOS-Base.repo
COPY ./requirements /tmp/requirements
RUN yum -y install epel-release && \
echo -e "[mysql]\nname=mysql\nbaseurl=${MYSQL_MIRROR}\ngpgcheck=0\nenabled=1" > /etc/yum.repos.d/mysql.repo
RUN yum -y install $(cat requirements/rpm_requirements.txt)
RUN pip install --upgrade pip setuptools==49.6.0 wheel -i ${PIP_MIRROR} && \
pip config set global.index-url ${PIP_MIRROR}
RUN pip install $(grep 'jms' requirements/requirements.txt) -i https://pypi.org/simple
RUN pip install -r requirements/requirements.txt
COPY --from=stage-build /opt/jumpserver/release/jumpserver /opt/jumpserver
echo -e "[mysql]\nname=mysql\nbaseurl=https://mirrors.tuna.tsinghua.edu.cn/mysql/yum/mysql57-community-el6/\ngpgcheck=0\nenabled=1" > /etc/yum.repos.d/mysql.repo
RUN cd /tmp/requirements && yum -y install $(cat rpm_requirements.txt)
RUN cd /tmp/requirements && pip install --upgrade pip setuptools && pip install wheel && \
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt || pip install -r requirements.txt
RUN mkdir -p /root/.ssh/ && echo -e "Host *\n\tStrictHostKeyChecking no\n\tUserKnownHostsFile /dev/null" > /root/.ssh/config
COPY . /opt/jumpserver
RUN echo > config.yml
VOLUME /opt/jumpserver/data
VOLUME /opt/jumpserver/logs

View File

@@ -4,10 +4,6 @@
[![Django](https://img.shields.io/badge/django-2.2-brightgreen.svg?style=plastic)](https://www.djangoproject.com/)
[![Docker Pulls](https://img.shields.io/docker/pulls/jumpserver/jms_all.svg)](https://hub.docker.com/u/jumpserver)
|Developer Wanted|
|------------------|
|JumpServer 正在寻找开发者,一起为改变世界做些贡献吧,哪怕一点点,联系我 <ibuler@fit2cloud.com> |
JumpServer 是全球首款开源的堡垒机,使用 GNU GPL v2.0 开源协议,是符合 4A 规范的运维安全审计系统。
JumpServer 使用 Python / Django 为主进行开发,遵循 Web 2.0 规范,配备了业界领先的 Web Terminal 方案,交互界面美观、用户体验好。
@@ -20,8 +16,8 @@ JumpServer 采纳分布式架构,支持多机房跨区域部署,支持横向
## 特色优势
- 开源: 零门槛,线上快速获取和安装;
- 分布式: 轻松支持大规模并发访问;
- 开源: 零门槛,线上快速获取和安装, 修复版本视情况而定
, 修复版本视情况而定- 分布式: 轻松支持大规模并发访问;
- 无插件: 仅需浏览器,极致的 Web Terminal 使用体验;
- 多云支持: 一套系统,同时管理不同云上面的资产;
- 云端存储: 审计录像云端存储,永不丢失;
@@ -206,16 +202,6 @@ v2.1.0 是 v2.0.0 之后的功能版本。
- [完整文档](https://docs.jumpserver.org)
- [演示视频](https://jumpserver.oss-cn-hangzhou.aliyuncs.com/jms-media/%E3%80%90%E6%BC%94%E7%A4%BA%E8%A7%86%E9%A2%91%E3%80%91Jumpserver%20%E5%A0%A1%E5%9E%92%E6%9C%BA%20V1.5.0%20%E6%BC%94%E7%A4%BA%E8%A7%86%E9%A2%91%20-%20final.mp4)
## 组件项目
- [Lina](https://github.com/jumpserver/lina) JumpServer Web UI 项目
- [Luna](https://github.com/jumpserver/luna) JumpServer Web Terminal 项目
- [Koko](https://github.com/jumpserver/koko) JumpServer 字符协议 Connector 项目,替代原来 Python 版本的 [Coco](https://github.com/jumpserver/coco)
- [Guacamole](https://github.com/jumpserver/docker-guacamole) JumpServer 图形协议 Connector 项目,依赖 [Apache Guacamole](https://guacamole.apache.org/)
## JumpServer 企业版
- [申请企业版试用](https://jinshuju.net/f/kyOYpi)
> 注:企业版支持离线安装,申请通过后会提供高速下载链接。
## 案例研究
- [JumpServer 堡垒机护航顺丰科技超大规模资产安全运维](https://blog.fit2cloud.com/?p=1147)

View File

@@ -1,3 +1,2 @@
from .remote_app import *
from .database_app import *
from .k8s_app import *

View File

@@ -1,20 +0,0 @@
# coding: utf-8
#
from orgs.mixins.api import OrgBulkModelViewSet
from .. import models
from .. import serializers
from ..hands import IsOrgAdminOrAppUser
__all__ = [
'K8sAppViewSet',
]
class K8sAppViewSet(OrgBulkModelViewSet):
model = models.K8sApp
filter_fields = ('name',)
search_fields = filter_fields
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.K8sAppSerializer

View File

@@ -23,7 +23,6 @@ REMOTE_APP_TYPE_CHROME_FIELDS = [
REMOTE_APP_TYPE_MYSQL_WORKBENCH_FIELDS = [
{'name': 'mysql_workbench_ip'},
{'name': 'mysql_workbench_name'},
{'name': 'mysql_workbench_port'},
{'name': 'mysql_workbench_username'},
{'name': 'mysql_workbench_password', 'write_only': True}
]

View File

@@ -1,34 +0,0 @@
# Generated by Django 2.2.13 on 2020-08-07 07:13
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
('applications', '0004_auto_20191218_1705'),
]
operations = [
migrations.CreateModel(
name='K8sApp',
fields=[
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('org_id', models.CharField(blank=True, db_index=True, default='', max_length=36, verbose_name='Organization')),
('name', models.CharField(max_length=128, verbose_name='Name')),
('type', models.CharField(choices=[('k8s', 'Kubernetes')], default='k8s', max_length=128, verbose_name='Type')),
('cluster', models.CharField(max_length=1024, verbose_name='Cluster')),
('comment', models.TextField(blank=True, default='', max_length=128, verbose_name='Comment')),
],
options={
'verbose_name': 'KubernetesApp',
'ordering': ('name',),
'unique_together': {('org_id', 'name')},
},
),
]

View File

@@ -1,3 +1,2 @@
from .remote_app import *
from .database_app import *
from .k8s_app import *

View File

@@ -1,27 +0,0 @@
from django.utils.translation import gettext_lazy as _
from common.db import models
from orgs.mixins.models import OrgModelMixin
class K8sApp(OrgModelMixin, models.JMSModel):
class TYPE(models.ChoiceSet):
K8S = 'k8s', _('Kubernetes')
name = models.CharField(max_length=128, verbose_name=_('Name'))
type = models.CharField(
default=TYPE.K8S, choices=TYPE.choices,
max_length=128, verbose_name=_('Type')
)
cluster = models.CharField(max_length=1024, verbose_name=_('Cluster'))
comment = models.TextField(
max_length=128, default='', blank=True, verbose_name=_('Comment')
)
def __str__(self):
return self.name
class Meta:
unique_together = [('org_id', 'name'), ]
verbose_name = _('KubernetesApp')
ordering = ('name', )

View File

@@ -1,3 +1,2 @@
from .remote_app import *
from .database_app import *
from .k8s_app import *

View File

@@ -1,22 +0,0 @@
from rest_framework import serializers
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .. import models
__all__ = [
'K8sAppSerializer',
]
class K8sAppSerializer(BulkOrgResourceModelSerializer):
type_display = serializers.CharField(source='get_type_display', read_only=True)
class Meta:
model = models.K8sApp
fields = [
'id', 'name', 'type', 'type_display', 'comment', 'created_by',
'date_created', 'date_updated', 'cluster'
]
read_only_fields = [
'id', 'created_by', 'date_created', 'date_updated',
]

View File

@@ -12,7 +12,6 @@ app_name = 'applications'
router = BulkRouter()
router.register(r'remote-apps', api.RemoteAppViewSet, 'remote-app')
router.register(r'database-apps', api.DatabaseAppViewSet, 'database-app')
router.register(r'k8s-apps', api.K8sAppViewSet, 'k8s-app')
urlpatterns = [
path('remote-apps/<uuid:pk>/connection-info/', api.RemoteAppConnectionInfoApi.as_view(), name='remote-app-connection-info'),

View File

@@ -1,4 +1,3 @@
from .mixin import *
from .admin_user import *
from .asset import *
from .label import *

View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
from assets.api import FilterAssetByNodeMixin
from rest_framework.viewsets import ModelViewSet
from rest_framework.generics import RetrieveAPIView
from django.shortcuts import get_object_or_404
@@ -14,7 +14,7 @@ from .. import serializers
from ..tasks import (
update_asset_hardware_info_manual, test_asset_connectivity_manual
)
from ..filters import FilterAssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend
from ..filters import AssetByNodeFilterBackend, LabelFilterBackend
logger = get_logger(__file__)
@@ -25,14 +25,14 @@ __all__ = [
]
class AssetViewSet(FilterAssetByNodeMixin, OrgBulkModelViewSet):
class AssetViewSet(OrgBulkModelViewSet):
"""
API endpoint that allows Asset to be viewed or edited.
"""
model = Asset
filter_fields = (
"hostname", "ip", "systemuser__id", "admin_user__id", "platform__base",
"is_active", 'ip'
"is_active"
)
search_fields = ("hostname", "ip")
ordering_fields = ("hostname", "ip", "port", "cpu_cores")
@@ -41,7 +41,7 @@ class AssetViewSet(FilterAssetByNodeMixin, OrgBulkModelViewSet):
'display': serializers.AssetDisplaySerializer,
}
permission_classes = (IsOrgAdminOrAppUser,)
extra_filter_backends = [FilterAssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend]
extra_filter_backends = [AssetByNodeFilterBackend, LabelFilterBackend]
def set_assets_node(self, assets):
if not isinstance(assets, list):

View File

@@ -1,89 +0,0 @@
from typing import List
from assets.models import Node, Asset
from assets.pagination import AssetLimitOffsetPagination
from common.utils import lazyproperty, dict_get_any, is_uuid, get_object_or_none
from assets.utils import get_node, is_query_node_all_assets
class SerializeToTreeNodeMixin:
permission_classes = ()
def serialize_nodes(self, nodes: List[Node], with_asset_amount=False):
if with_asset_amount:
def _name(node: Node):
return '{} ({})'.format(node.value, node.assets_amount)
else:
def _name(node: Node):
return node.value
data = [
{
'id': node.key,
'name': _name(node),
'title': _name(node),
'pId': node.parent_key,
'isParent': True,
'open': node.is_org_root(),
'meta': {
'node': {
"id": node.id,
"key": node.key,
"value": node.value,
},
'type': 'node'
}
}
for node in nodes
]
return data
def get_platform(self, asset: Asset):
default = 'file'
icon = {'windows', 'linux'}
platform = asset.platform_base.lower()
if platform in icon:
return platform
return default
def serialize_assets(self, assets, node_key=None):
if node_key is None:
get_pid = lambda asset: getattr(asset, 'parent_key', '')
else:
get_pid = lambda asset: node_key
data = [
{
'id': str(asset.id),
'name': asset.hostname,
'title': asset.ip,
'pId': get_pid(asset),
'isParent': False,
'open': False,
'iconSkin': self.get_platform(asset),
'chkDisabled': not asset.is_active,
'meta': {
'type': 'asset',
'asset': {
'id': asset.id,
'hostname': asset.hostname,
'ip': asset.ip,
'protocols': asset.protocols_as_list,
'platform': asset.platform_base,
},
}
}
for asset in assets
]
return data
class FilterAssetByNodeMixin:
pagination_class = AssetLimitOffsetPagination
@lazyproperty
def is_query_node_all_assets(self):
return is_query_node_all_assets(self.request)
@lazyproperty
def node(self):
return get_node(self.request)

View File

@@ -1,24 +1,16 @@
# ~*~ coding: utf-8 ~*~
from functools import partial
from collections import namedtuple, defaultdict
from collections import namedtuple
from rest_framework import status
from rest_framework.serializers import ValidationError
from rest_framework.response import Response
from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404, Http404
from django.utils.decorators import method_decorator
from django.db.models.signals import m2m_changed
from common.exceptions import SomeoneIsDoingThis
from common.const.signals import PRE_REMOVE, POST_REMOVE
from assets.models import Asset
from common.utils import get_logger, get_object_or_none
from common.tree import TreeNodeSerializer
from common.const.distributed_lock_key import UPDATE_NODE_TREE_LOCK_KEY
from orgs.mixins.api import OrgModelViewSet
from orgs.mixins import generics
from orgs.lock import org_level_transaction_lock
from ..hands import IsOrgAdmin
from ..models import Node
from ..tasks import (
@@ -26,13 +18,12 @@ from ..tasks import (
test_node_assets_connectivity_manual,
)
from .. import serializers
from .mixin import SerializeToTreeNodeMixin
logger = get_logger(__file__)
__all__ = [
'NodeViewSet', 'NodeChildrenApi', 'NodeAssetsApi',
'NodeAddAssetsApi', 'NodeRemoveAssetsApi', 'MoveAssetsToNodeApi',
'NodeAddAssetsApi', 'NodeRemoveAssetsApi', 'NodeReplaceAssetsApi',
'NodeAddChildrenApi', 'NodeListAsTreeApi',
'NodeChildrenAsTreeApi',
'NodeTaskCreateApi',
@@ -145,7 +136,7 @@ class NodeChildrenApi(generics.ListCreateAPIView):
return queryset
class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
class NodeChildrenAsTreeApi(NodeChildrenApi):
"""
节点子节点作为树返回,
[
@@ -159,23 +150,31 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
"""
model = Node
serializer_class = TreeNodeSerializer
http_method_names = ['get']
def list(self, request, *args, **kwargs):
nodes = self.get_queryset().order_by('value')
nodes = self.serialize_nodes(nodes, with_asset_amount=True)
assets = self.get_assets()
data = [*nodes, *assets]
return Response(data=data)
def get_queryset(self):
queryset = super().get_queryset()
queryset = [node.as_tree_node() for node in queryset]
queryset = self.add_assets_if_need(queryset)
queryset = sorted(queryset)
return queryset
def get_assets(self):
def add_assets_if_need(self, queryset):
include_assets = self.request.query_params.get('assets', '0') == '1'
if not include_assets:
return []
return queryset
assets = self.instance.get_assets().only(
"id", "hostname", "ip", "os",
"org_id", "protocols",
)
return self.serialize_assets(assets, self.instance.key)
for asset in assets:
queryset.append(asset.as_tree_node(self.instance))
return queryset
def check_need_refresh_nodes(self):
if self.request.query_params.get('refresh', '0') == '1':
Node.refresh_nodes()
class NodeAssetsApi(generics.ListAPIView):
@@ -209,8 +208,6 @@ class NodeAddChildrenApi(generics.UpdateAPIView):
return Response("OK")
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeAddAssetsApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
@@ -223,8 +220,6 @@ class NodeAddAssetsApi(generics.UpdateAPIView):
instance.assets.add(*tuple(assets))
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeRemoveAssetsApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
@@ -233,17 +228,15 @@ class NodeRemoveAssetsApi(generics.UpdateAPIView):
def perform_update(self, serializer):
assets = serializer.validated_data.get('assets')
node = self.get_object()
node.assets.remove(*assets)
# 把孤儿资产添加到 root 节点
orphan_assets = Asset.objects.filter(id__in=[a.id for a in assets], nodes__isnull=True).distinct()
Node.org_root().assets.add(*orphan_assets)
instance = self.get_object()
if instance != Node.org_root():
instance.assets.remove(*tuple(assets))
else:
assets = [asset for asset in assets if asset.nodes.count() > 1]
instance.assets.remove(*tuple(assets))
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class MoveAssetsToNodeApi(generics.UpdateAPIView):
class NodeReplaceAssetsApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
permission_classes = (IsOrgAdmin,)
@@ -251,39 +244,9 @@ class MoveAssetsToNodeApi(generics.UpdateAPIView):
def perform_update(self, serializer):
assets = serializer.validated_data.get('assets')
node = self.get_object()
self.remove_old_nodes(assets)
node.assets.add(*assets)
def remove_old_nodes(self, assets):
m2m_model = Asset.nodes.through
# 查询资产与节点关系表,查出要移动资产与节点的所有关系
relates = m2m_model.objects.filter(asset__in=assets).values_list('asset_id', 'node_id')
if relates:
# 对关系以资产进行分组,用来发 `reverse=False` 信号
asset_nodes_mapper = defaultdict(set)
for asset_id, node_id in relates:
asset_nodes_mapper[asset_id].add(node_id)
# 组建一个资产 id -> Asset 的 mapper
asset_mapper = {asset.id: asset for asset in assets}
# 创建删除关系信号发送函数
senders = []
for asset_id, node_id_set in asset_nodes_mapper.items():
senders.append(partial(m2m_changed.send, sender=m2m_model, instance=asset_mapper[asset_id],
reverse=False, model=Node, pk_set=node_id_set))
# 发送 pre 信号
[sender(action=PRE_REMOVE) for sender in senders]
num = len(relates)
asset_ids, node_ids = zip(*relates)
# 删除之前的关系
rows, _i = m2m_model.objects.filter(asset_id__in=asset_ids, node_id__in=node_ids).delete()
if rows != num:
raise SomeoneIsDoingThis
# 发送 post 信号
[sender(action=POST_REMOVE) for sender in senders]
instance = self.get_object()
for asset in assets:
asset.nodes.set([instance])
class NodeTaskCreateApi(generics.CreateAPIView):
@@ -304,6 +267,7 @@ class NodeTaskCreateApi(generics.CreateAPIView):
@staticmethod
def refresh_nodes_cache():
Node.refresh_nodes()
Task = namedtuple('Task', ['id'])
task = Task(id="0")
return task

View File

@@ -1,6 +0,0 @@
from rest_framework import status
from common.exceptions import JMSException
class NodeIsBeingUpdatedByOthers(JMSException):
status_code = status.HTTP_409_CONFLICT

View File

@@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
#
from rest_framework.compat import coreapi, coreschema
import coreapi
from rest_framework import filters
from django.db.models import Q
from .models import Label
from assets.utils import is_query_node_all_assets, get_node
from common.utils import dict_get_any, is_uuid, get_object_or_none
from .models import Node, Label
class AssetByNodeFilterBackend(filters.BaseFilterBackend):
@@ -21,54 +21,47 @@ class AssetByNodeFilterBackend(filters.BaseFilterBackend):
for field in self.fields
]
def filter_node_related_all(self, queryset, node):
return queryset.filter(
Q(nodes__key__istartswith=f'{node.key}:') |
Q(nodes__key=node.key)
).distinct()
@staticmethod
def is_query_all(request):
query_all_arg = request.query_params.get('all')
show_current_asset_arg = request.query_params.get('show_current_asset')
def filter_node_related_direct(self, queryset, node):
return queryset.filter(nodes__key=node.key).distinct()
query_all = query_all_arg == '1'
if show_current_asset_arg is not None:
query_all = show_current_asset_arg != '1'
return query_all
@staticmethod
def get_query_node(request):
node_id = dict_get_any(request.query_params, ['node', 'node_id'])
if not node_id:
return None, False
if is_uuid(node_id):
node = get_object_or_none(Node, id=node_id)
else:
node = get_object_or_none(Node, key=node_id)
return node, True
@staticmethod
def perform_query(pattern, queryset):
return queryset.filter(nodes__key__regex=pattern).distinct()
def filter_queryset(self, request, queryset, view):
node = get_node(request)
if node is None:
node, has_query_arg = self.get_query_node(request)
if not has_query_arg:
return queryset
query_all = is_query_node_all_assets(request)
if query_all:
return self.filter_node_related_all(queryset, node)
else:
return self.filter_node_related_direct(queryset, node)
class FilterAssetByNodeFilterBackend(filters.BaseFilterBackend):
"""
需要与 `assets.api.mixin.FilterAssetByNodeMixin` 配合使用
"""
fields = ['node', 'all']
def get_schema_fields(self, view):
return [
coreapi.Field(
name=field, location='query', required=False,
type='string', example='', description='', schema=None,
)
for field in self.fields
]
def filter_queryset(self, request, queryset, view):
node = view.node
if node is None:
return queryset
query_all = view.is_query_node_all_assets
query_all = self.is_query_all(request)
if query_all:
return queryset.filter(
Q(nodes__key__istartswith=f'{node.key}:') |
Q(nodes__key=node.key)
).distinct()
pattern = node.get_all_children_pattern(with_self=True)
else:
return queryset.filter(nodes__key=node.key).distinct()
# pattern = node.get_children_key_pattern(with_self=True)
# 只显示当前节点下资产
pattern = r"^{}$".format(node.key)
return self.perform_query(pattern, queryset)
class LabelFilterBackend(filters.BaseFilterBackend):
@@ -120,32 +113,7 @@ class LabelFilterBackend(filters.BaseFilterBackend):
class AssetRelatedByNodeFilterBackend(AssetByNodeFilterBackend):
def filter_node_related_all(self, queryset, node):
return queryset.filter(
Q(asset__nodes__key__istartswith=f'{node.key}:') |
Q(asset__nodes__key=node.key)
).distinct()
@staticmethod
def perform_query(pattern, queryset):
return queryset.filter(asset__nodes__key__regex=pattern).distinct()
def filter_node_related_direct(self, queryset, node):
return queryset.filter(asset__nodes__key=node.key).distinct()
class IpInFilterBackend(filters.BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
ips = request.query_params.get('ips')
if not ips:
return queryset
ip_list = [i.strip() for i in ips.split(',')]
queryset = queryset.filter(ip__in=ip_list)
return queryset
def get_schema_fields(self, view):
return [
coreapi.Field(
name='ips', location='query', required=False, type='string',
schema=coreschema.String(
title='ips',
description='ip in filter'
)
)
]

View File

@@ -1,18 +0,0 @@
# Generated by Django 2.2.10 on 2020-07-11 09:40
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0049_systemuser_sftp_root'),
]
operations = [
migrations.AlterField(
model_name='asset',
name='created_by',
field=models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by'),
),
]

View File

@@ -1,22 +0,0 @@
# Generated by Django 2.2.10 on 2020-07-13 03:43
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0050_auto_20200711_1740'),
]
operations = [
migrations.AlterField(
model_name='domain',
name='name',
field=models.CharField(max_length=128, verbose_name='Name'),
),
migrations.AlterUniqueTogether(
name='domain',
unique_together={('org_id', 'name')},
),
]

View File

@@ -1,22 +0,0 @@
# Generated by Django 2.2.10 on 2020-07-15 07:35
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0051_auto_20200713_1143'),
]
operations = [
migrations.AlterField(
model_name='commandfilter',
name='name',
field=models.CharField(max_length=64, verbose_name='Name'),
),
migrations.AlterUniqueTogether(
name='commandfilter',
unique_together={('org_id', 'name')},
),
]

View File

@@ -1,34 +0,0 @@
# Generated by Django 2.2.10 on 2020-07-23 04:32
import django.core.validators
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0052_auto_20200715_1535'),
]
operations = [
migrations.AlterField(
model_name='adminuser',
name='username',
field=models.CharField(blank=True, db_index=True, max_length=128, validators=[django.core.validators.RegexValidator('^[0-9a-zA-Z_@\\-\\.]*$', 'Special char not allowed')], verbose_name='Username'),
),
migrations.AlterField(
model_name='authbook',
name='username',
field=models.CharField(blank=True, db_index=True, max_length=128, validators=[django.core.validators.RegexValidator('^[0-9a-zA-Z_@\\-\\.]*$', 'Special char not allowed')], verbose_name='Username'),
),
migrations.AlterField(
model_name='gateway',
name='username',
field=models.CharField(blank=True, db_index=True, max_length=128, validators=[django.core.validators.RegexValidator('^[0-9a-zA-Z_@\\-\\.]*$', 'Special char not allowed')], verbose_name='Username'),
),
migrations.AlterField(
model_name='systemuser',
name='username',
field=models.CharField(blank=True, db_index=True, max_length=128, validators=[django.core.validators.RegexValidator('^[0-9a-zA-Z_@\\-\\.]*$', 'Special char not allowed')], verbose_name='Username'),
),
]

View File

@@ -1,23 +0,0 @@
# Generated by Django 2.2.13 on 2020-08-07 02:32
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0053_auto_20200723_1232'),
]
operations = [
migrations.AddField(
model_name='systemuser',
name='token',
field=models.TextField(default='', verbose_name='Token'),
),
migrations.AlterField(
model_name='systemuser',
name='protocol',
field=models.CharField(choices=[('ssh', 'ssh'), ('rdp', 'rdp'), ('telnet', 'telnet'), ('vnc', 'vnc'), ('mysql', 'mysql'), ('k8s', 'k8s')], default='ssh', max_length=16, verbose_name='Protocol'),
),
]

View File

@@ -1,23 +0,0 @@
# Generated by Django 2.2.13 on 2020-08-11 10:45
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0054_auto_20200807_1032'),
]
operations = [
migrations.AddField(
model_name='systemuser',
name='home',
field=models.CharField(blank=True, default='', max_length=4096, verbose_name='Home'),
),
migrations.AddField(
model_name='systemuser',
name='system_groups',
field=models.CharField(blank=True, default='', max_length=4096, verbose_name='System groups'),
),
]

View File

@@ -1,23 +0,0 @@
# Generated by Django 2.2.13 on 2020-09-04 09:51
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('assets', '0055_auto_20200811_1845'),
]
operations = [
migrations.AddField(
model_name='node',
name='assets_amount',
field=models.IntegerField(default=0),
),
migrations.AddField(
model_name='node',
name='parent_key',
field=models.CharField(db_index=True, default='', max_length=64, verbose_name='Parent key'),
),
]

View File

@@ -1,38 +0,0 @@
# Generated by Django 2.2.13 on 2020-08-21 08:20
from django.db import migrations
from django.db.models import Q
def fill_node_value(apps, schema_editor):
Node = apps.get_model('assets', 'Node')
Asset = apps.get_model('assets', 'Asset')
node_queryset = Node.objects.all()
node_amount = node_queryset.count()
width = len(str(node_amount))
print('\n')
for i, node in enumerate(node_queryset):
print(f'\t{i+1:0>{width}}/{node_amount} compute node[{node.key}]`s assets_amount ...')
assets_amount = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
key = node.key
try:
parent_key = key[:key.rindex(':')]
except ValueError:
parent_key = ''
node.assets_amount = assets_amount
node.parent_key = parent_key
node.save()
print(' ' + '.'*65, end='')
class Migration(migrations.Migration):
dependencies = [
('assets', '0056_auto_20200904_1751'),
]
operations = [
migrations.RunPython(fill_node_value)
]

View File

@@ -47,10 +47,6 @@ class AssetManager(OrgManager):
)
class AssetOrgManager(OrgManager):
pass
class AssetQuerySet(models.QuerySet):
def active(self):
return self.filter(is_active=True)
@@ -225,12 +221,11 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
hostname_raw = models.CharField(max_length=128, blank=True, null=True, verbose_name=_('Hostname raw'))
labels = models.ManyToManyField('assets.Label', blank=True, related_name='assets', verbose_name=_("Labels"))
created_by = models.CharField(max_length=128, null=True, blank=True, verbose_name=_('Created by'))
created_by = models.CharField(max_length=32, null=True, blank=True, verbose_name=_('Created by'))
date_created = models.DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created'))
comment = models.TextField(max_length=128, default='', blank=True, verbose_name=_('Comment'))
objects = AssetManager.from_queryset(AssetQuerySet)()
org_objects = AssetOrgManager.from_queryset(AssetQuerySet)()
_connectivity = None
def __str__(self):
@@ -249,6 +244,10 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
def platform_base(self):
return self.platform.base
@lazyproperty
def admin_user_display(self):
return self.admin_user.name
@lazyproperty
def admin_user_username(self):
"""求可连接性时直接用用户名去取避免再查一次admin user
@@ -355,4 +354,36 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
class Meta:
unique_together = [('org_id', 'hostname')]
verbose_name = _("Asset")
ordering = ['-date_created']
@classmethod
def generate_fake(cls, count=100):
from .user import AdminUser, SystemUser
from random import seed, choice
from django.db import IntegrityError
from .node import Node
from orgs.utils import get_current_org
from orgs.models import Organization
org = get_current_org()
if not org or not org.is_real():
Organization.default().change_to()
nodes = list(Node.objects.all())
seed()
for i in range(count):
ip = [str(i) for i in random.sample(range(255), 4)]
asset = cls(ip='.'.join(ip),
hostname='.'.join(ip),
admin_user=choice(AdminUser.objects.all()),
created_by='Fake')
try:
asset.save()
asset.protocols = 'ssh/22'
if nodes and len(nodes) > 3:
_nodes = random.sample(nodes, 3)
else:
_nodes = [Node.default_node()]
asset.nodes.set(_nodes)
logger.debug('Generate fake asset : %s' % asset.ip)
except IntegrityError:
print('Error continue')
continue

View File

@@ -3,6 +3,7 @@
from django.db import models, transaction
from django.db.models import Max
from django.core.cache import cache
from django.utils.translation import ugettext_lazy as _
from orgs.mixins.models import OrgManager
@@ -58,17 +59,19 @@ class AuthBook(BaseUser):
"""
username = kwargs['username']
asset = kwargs['asset']
with transaction.atomic():
# 使用select_for_update限制并发创建相同的username、asset条目
instances = cls.objects.select_for_update().filter(username=username, asset=asset)
instances.filter(is_latest=True).update(is_latest=False)
max_version = cls.get_max_version(username, asset)
kwargs.update({
'version': max_version + 1,
'is_latest': True
})
obj = cls.objects.create(**kwargs)
return obj
key_lock = 'KEY_LOCK_CREATE_AUTH_BOOK_{}_{}'.format(username, asset.id)
with cache.lock(key_lock):
with transaction.atomic():
cls.objects.filter(
username=username, asset=asset, is_latest=True
).update(is_latest=False)
max_version = cls.get_max_version(username, asset)
kwargs.update({
'version': max_version + 1,
'is_latest': True
})
obj = cls.objects.create(**kwargs)
return obj
@property
def connectivity(self):

View File

@@ -158,11 +158,9 @@ class AuthMixin:
if update_fields:
self.save(update_fields=update_fields)
def has_special_auth(self, asset=None, username=None):
def has_special_auth(self, asset=None):
from .authbook import AuthBook
if username is None:
username = self.username
queryset = AuthBook.objects.filter(username=username)
queryset = AuthBook.objects.filter(username=self.username)
if asset:
queryset = queryset.filter(asset=asset)
return queryset.exists()
@@ -232,7 +230,7 @@ class AuthMixin:
class BaseUser(OrgModelMixin, AuthMixin, ConnectivityMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=128, verbose_name=_('Name'))
username = models.CharField(max_length=128, blank=True, verbose_name=_('Username'), validators=[alphanumeric], db_index=True)
username = models.CharField(max_length=32, blank=True, verbose_name=_('Username'), validators=[alphanumeric], db_index=True)
password = fields.EncryptCharField(max_length=256, blank=True, null=True, verbose_name=_('Password'))
private_key = fields.EncryptTextField(blank=True, null=True, verbose_name=_('SSH private key'))
public_key = fields.EncryptTextField(blank=True, null=True, verbose_name=_('SSH public key'))

View File

@@ -38,3 +38,27 @@ class Cluster(models.Model):
class Meta:
ordering = ['name']
verbose_name = _("Cluster")
@classmethod
def generate_fake(cls, count=5):
from random import seed, choice
import forgery_py
from django.db import IntegrityError
seed()
for i in range(count):
cluster = cls(name=forgery_py.name.full_name(),
bandwidth='200M',
contact=forgery_py.name.full_name(),
phone=forgery_py.address.phone(),
address=forgery_py.address.city() + forgery_py.address.street_address(),
# operator=choice(['北京联通', '北京电信', 'BGP全网通']),
operator=choice([_('Beijing unicom'), _('Beijing telecom'), _('BGP full netcom')]),
comment=forgery_py.lorem_ipsum.sentence(),
created_by='Fake')
try:
cluster.save()
logger.debug('Generate fake asset group: %s' % cluster.name)
except IntegrityError:
print('Error continue')
continue

View File

@@ -18,7 +18,7 @@ __all__ = [
class CommandFilter(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=64, verbose_name=_("Name"))
name = models.CharField(max_length=64, unique=True, verbose_name=_("Name"))
is_active = models.BooleanField(default=True, verbose_name=_('Is active'))
comment = models.TextField(blank=True, default='', verbose_name=_("Comment"))
date_created = models.DateTimeField(auto_now_add=True)
@@ -29,7 +29,6 @@ class CommandFilter(OrgModelMixin):
return self.name
class Meta:
unique_together = [('org_id', 'name')]
verbose_name = _("Command filter")

View File

@@ -17,14 +17,13 @@ __all__ = ['Domain', 'Gateway']
class Domain(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=128, verbose_name=_('Name'))
name = models.CharField(max_length=128, unique=True, verbose_name=_('Name'))
comment = models.TextField(blank=True, verbose_name=_('Comment'))
date_created = models.DateTimeField(auto_now_add=True, null=True,
verbose_name=_('Date created'))
class Meta:
verbose_name = _("Domain")
unique_together = [('org_id', 'name')]
def __str__(self):
return self.name

View File

@@ -18,11 +18,3 @@ class FavoriteAsset(CommonModelMixin):
@classmethod
def get_user_favorite_assets_id(cls, user):
return cls.objects.filter(user=user).values_list('asset', flat=True)
@classmethod
def get_user_favorite_assets(cls, user):
from assets.models import Asset
from perms.utils.user_asset_permission import get_user_granted_all_assets
asset_ids = get_user_granted_all_assets(user).values_list('id', flat=True)
query_name = cls.asset.field.related_query_name()
return Asset.org_objects.filter(**{f'{query_name}__user_id': user.id}, id__in=asset_ids).distinct()

View File

@@ -33,3 +33,21 @@ class AssetGroup(models.Model):
def initial(cls):
asset_group = cls(name=_('Default'), comment=_('Default asset group'))
asset_group.save()
@classmethod
def generate_fake(cls, count=100):
from random import seed
import forgery_py
from django.db import IntegrityError
seed()
for i in range(count):
group = cls(name=forgery_py.name.full_name(),
comment=forgery_py.lorem_ipsum.sentence(),
created_by='Fake')
try:
group.save()
logger.debug('Generate fake asset group: %s' % group.name)
except IntegrityError:
print('Error continue')
continue

View File

@@ -2,35 +2,132 @@
#
import uuid
import re
import time
from django.db import models, transaction
from django.db.models import Q
from django.db.utils import IntegrityError
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ugettext
from django.db.transaction import atomic
from django.core.cache import cache
from common.utils import get_logger
from common.utils.common import lazyproperty
from common.utils import get_logger, lazyproperty
from orgs.mixins.models import OrgModelMixin, OrgManager
from orgs.utils import get_current_org, tmp_to_org, current_org
from orgs.models import Organization
__all__ = ['Node', 'FamilyMixin', 'compute_parent_key']
__all__ = ['Node']
logger = get_logger(__name__)
def compute_parent_key(key):
try:
return key[:key.rindex(':')]
except ValueError:
return ''
class NodeQuerySet(models.QuerySet):
def delete(self):
raise NotImplementedError
raise PermissionError("Bulk delete node deny")
class TreeCache:
updated_time_cache_key = 'NODE_TREE_UPDATED_AT_{}'
cache_time = 3600
assets_updated_time_cache_key = 'NODE_TREE_ASSETS_UPDATED_AT_{}'
def __init__(self, tree, org_id):
now = time.time()
self.created_time = now
self.assets_created_time = now
self.tree = tree
self.org_id = org_id
def _has_changed(self, tp="tree"):
if tp == "assets":
key = self.assets_updated_time_cache_key.format(self.org_id)
else:
key = self.updated_time_cache_key.format(self.org_id)
updated_time = cache.get(key, 0)
if updated_time > self.created_time:
return True
else:
return False
@classmethod
def set_changed(cls, tp="tree", t=None, org_id=None):
if org_id is None:
org_id = current_org.id
if tp == "assets":
key = cls.assets_updated_time_cache_key.format(org_id)
else:
key = cls.updated_time_cache_key.format(org_id)
ttl = cls.cache_time
if not t:
t = time.time()
cache.set(key, t, ttl)
def tree_has_changed(self):
return self._has_changed("tree")
def set_tree_changed(self, t=None):
logger.debug("Set tree tree changed")
self.__class__.set_changed(t=t, tp="tree")
def assets_has_changed(self):
return self._has_changed("assets")
def set_tree_assets_changed(self, t=None):
logger.debug("Set tree assets changed")
self.__class__.set_changed(t=t, tp="assets")
def get(self):
if self.tree_has_changed():
self.renew()
return self.tree
if self.assets_has_changed():
self.tree.init_assets()
return self.tree
def renew(self):
new_obj = self.__class__.new(self.org_id)
self.tree = new_obj.tree
self.created_time = new_obj.created_time
self.assets_created_time = new_obj.assets_created_time
@classmethod
def new(cls, org_id=None):
from ..utils import TreeService
logger.debug("Create node tree")
if not org_id:
org_id = current_org.id
with tmp_to_org(org_id):
tree = TreeService.new()
obj = cls(tree, org_id)
obj.tree = tree
return obj
class TreeMixin:
_org_tree_map = {}
@classmethod
def tree(cls):
org_id = current_org.org_id()
t = cls.get_local_tree_cache(org_id)
if t is None:
t = TreeCache.new()
cls._org_tree_map[org_id] = t
return t.get()
@classmethod
def get_local_tree_cache(cls, org_id=None):
t = cls._org_tree_map.get(org_id)
return t
@classmethod
def refresh_tree(cls, t=None):
TreeCache.set_changed(tp="tree", t=t, org_id=current_org.id)
@classmethod
def refresh_node_assets(cls, t=None):
TreeCache.set_changed(tp="assets", t=t, org_id=current_org.id)
class FamilyMixin:
@@ -41,16 +138,16 @@ class FamilyMixin:
@staticmethod
def clean_children_keys(nodes_keys):
sort_key = lambda k: [int(i) for i in k.split(':')]
nodes_keys = sorted(list(nodes_keys), key=sort_key)
nodes_keys = sorted(list(nodes_keys), key=lambda x: (len(x), x))
nodes_keys_clean = []
base_key = ''
for key in nodes_keys:
if key.startswith(base_key + ':'):
continue
nodes_keys_clean.append(key)
base_key = key
for key in nodes_keys[::-1]:
found = False
for k in nodes_keys:
if key.startswith(k + ':'):
found = True
break
if not found:
nodes_keys_clean.append(key)
return nodes_keys_clean
@classmethod
@@ -78,16 +175,13 @@ class FamilyMixin:
return re.match(children_pattern, self.key)
def get_children(self, with_self=False):
q = Q(parent_key=self.key)
if with_self:
q |= Q(key=self.key)
return Node.objects.filter(q)
pattern = self.get_children_key_pattern(with_self=with_self)
return Node.objects.filter(key__regex=pattern)
def get_all_children(self, with_self=False):
q = Q(key__istartswith=f'{self.key}:')
if with_self:
q |= Q(key=self.key)
return Node.objects.filter(q)
pattern = self.get_all_children_pattern(with_self=with_self)
children = Node.objects.filter(key__regex=pattern)
return children
@property
def children(self):
@@ -97,30 +191,14 @@ class FamilyMixin:
def all_children(self):
return self.get_all_children(with_self=False)
def create_child(self, value=None, _id=None):
with atomic(savepoint=False):
def create_child(self, value, _id=None):
with transaction.atomic():
child_key = self.get_next_child_key()
if value is None:
value = child_key
child = self.__class__.objects.create(
id=_id, key=child_key, value=value, parent_key=self.key,
id=_id, key=child_key, value=value
)
return child
def get_or_create_child(self, value, _id=None):
"""
:return: Node, bool (created)
"""
children = self.get_children()
exist = children.filter(value=value).exists()
if exist:
child = children.filter(value=value).first()
created = False
else:
child = self.create_child(value, _id)
created = True
return child, created
def get_next_child_key(self):
mark = self.child_mark
self.child_mark += 1
@@ -163,13 +241,10 @@ class FamilyMixin:
ancestor_keys = self.get_ancestor_keys(with_self=with_self)
return self.__class__.objects.filter(key__in=ancestor_keys)
# @property
# def parent_key(self):
# parent_key = ":".join(self.key.split(":")[:-1])
# return parent_key
def compute_parent_key(self):
return compute_parent_key(self.key)
@property
def parent_key(self):
parent_key = ":".join(self.key.split(":")[:-1])
return parent_key
def is_parent(self, other):
return other.is_children(self)
@@ -211,33 +286,39 @@ class FamilyMixin:
return [*tuple(ancestors), self, *tuple(children)]
class FullValueMixin:
key = ''
@lazyproperty
def full_value(self):
if self.is_org_root():
return self.value
value = self.tree().get_node_full_tag(self.key)
return value
class NodeAssetsMixin:
key = ''
id = None
@lazyproperty
def assets_amount(self):
amount = self.tree().assets_amount(self.key)
return amount
def get_all_assets(self):
from .asset import Asset
q = Q(nodes__key__startswith=f'{self.key}:') | Q(nodes__key=self.key)
return Asset.objects.filter(q).distinct()
@classmethod
def get_node_all_assets_by_key_v2(cls, key):
# 最初的写法是:
# Asset.objects.filter(Q(nodes__key__startswith=f'{node.key}:') | Q(nodes__id=node.id))
# 可是 startswith 会导致表关联时 Asset 索引失效
from .asset import Asset
node_ids = cls.objects.filter(
Q(key__startswith=f'{key}:') |
Q(key=key)
).values_list('id', flat=True).distinct()
assets = Asset.objects.filter(
nodes__id__in=list(node_ids)
).distinct()
return assets
if self.is_org_root():
return Asset.objects.filter(org_id=self.org_id)
pattern = '^{0}$|^{0}:'.format(self.key)
return Asset.objects.filter(nodes__key__regex=pattern).distinct()
def get_assets(self):
from .asset import Asset
assets = Asset.objects.filter(nodes=self)
if self.is_org_root():
assets = Asset.objects.filter(Q(nodes=self) | Q(nodes__isnull=True))
else:
assets = Asset.objects.filter(nodes=self)
return assets.distinct()
def get_valid_assets(self):
@@ -246,54 +327,51 @@ class NodeAssetsMixin:
def get_all_valid_assets(self):
return self.get_all_assets().valid()
@classmethod
def _get_nodes_all_assets(cls, nodes_keys):
"""
当节点比较多的时候,这种正则方式性能差极了
:param nodes_keys:
:return:
"""
from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys)
nodes_children_pattern = set()
for key in nodes_keys:
children_pattern = cls.get_node_all_children_key_pattern(key)
nodes_children_pattern.add(children_pattern)
pattern = '|'.join(nodes_children_pattern)
return Asset.objects.filter(nodes__key__regex=pattern).distinct()
@classmethod
def get_nodes_all_assets_ids(cls, nodes_keys):
assets_ids = cls.get_nodes_all_assets(nodes_keys).values_list('id', flat=True)
nodes_keys = cls.clean_children_keys(nodes_keys)
assets_ids = set()
for key in nodes_keys:
node_assets_ids = cls.tree().all_assets(key)
assets_ids.update(set(node_assets_ids))
return assets_ids
@classmethod
def get_nodes_all_assets(cls, nodes_keys, extra_assets_ids=None):
from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys)
q = Q()
node_ids = ()
for key in nodes_keys:
q |= Q(key__startswith=f'{key}:')
q |= Q(key=key)
if q:
node_ids = Node.objects.filter(q).distinct().values_list('id', flat=True)
q = Q(nodes__id__in=list(node_ids))
assets_ids = cls.get_nodes_all_assets_ids(nodes_keys)
if extra_assets_ids:
q |= Q(id__in=extra_assets_ids)
if q:
return Asset.org_objects.filter(q).distinct()
else:
return Asset.objects.none()
assets_ids.update(set(extra_assets_ids))
return Asset.objects.filter(id__in=assets_ids)
class SomeNodesMixin:
key = ''
default_key = '1'
default_value = 'Default'
ungrouped_key = '-10'
ungrouped_value = _('ungrouped')
empty_key = '-11'
empty_value = _("empty")
@classmethod
def default_node(cls):
with tmp_to_org(Organization.default()):
defaults = {'value': cls.default_value}
try:
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
except IntegrityError as e:
logger.error("Create default node failed: {}".format(e))
cls.modify_other_org_root_node_key()
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
return obj
favorite_key = '-12'
favorite_value = _("favorite")
def is_default_node(self):
return self.key == self.default_key
@@ -328,15 +406,51 @@ class SomeNodesMixin:
@classmethod
def org_root(cls):
root = cls.objects.filter(parent_key='').exclude(key__startswith='-')
root = cls.objects.filter(key__regex=r'^[0-9]+$')
if root:
return root[0]
else:
return cls.create_org_root_node()
@classmethod
def ungrouped_node(cls):
with tmp_to_org(Organization.system()):
defaults = {'value': cls.ungrouped_value}
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.ungrouped_key
)
return obj
@classmethod
def default_node(cls):
with tmp_to_org(Organization.default()):
defaults = {'value': cls.default_value}
try:
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
except IntegrityError as e:
logger.error("Create default node failed: {}".format(e))
cls.modify_other_org_root_node_key()
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
return obj
@classmethod
def favorite_node(cls):
with tmp_to_org(Organization.system()):
defaults = {'value': cls.favorite_value}
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.favorite_key
)
return obj
@classmethod
def initial_some_nodes(cls):
cls.default_node()
cls.ungrouped_node()
cls.favorite_node()
@classmethod
def modify_other_org_root_node_key(cls):
@@ -368,15 +482,12 @@ class SomeNodesMixin:
logger.info('Modify key ( {} > {} )'.format(old_key, new_key))
class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
class Node(OrgModelMixin, SomeNodesMixin, TreeMixin, FamilyMixin, FullValueMixin, NodeAssetsMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1'
value = models.CharField(max_length=128, verbose_name=_("Value"))
child_mark = models.IntegerField(default=0)
date_create = models.DateTimeField(auto_now_add=True)
parent_key = models.CharField(max_length=64, verbose_name=_("Parent key"),
db_index=True, default='')
assets_amount = models.IntegerField(default=0)
objects = OrgManager.from_queryset(NodeQuerySet)()
is_node = True
@@ -411,20 +522,18 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
def name(self):
return self.value
@lazyproperty
def full_value(self):
# 不要在列表中调用该属性
values = self.__class__.objects.filter(
key__in=self.get_ancestor_keys()
).values_list('key', 'value')
values = [v for k, v in sorted(values, key=lambda x: len(x[0]))]
values.append(self.value)
return ' / '.join(values)
@property
def level(self):
return len(self.key.split(':'))
@classmethod
def refresh_nodes(cls):
cls.refresh_tree()
@classmethod
def refresh_assets(cls):
cls.refresh_node_assets()
def as_tree_node(self):
from common.tree import TreeNode
name = '{} ({})'.format(self.value, self.assets_amount)
@@ -458,3 +567,20 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
if self.has_children_or_has_assets():
return
return super().delete(using=using, keep_parents=keep_parents)
@classmethod
def generate_fake(cls, count=100):
import random
org = get_current_org()
if not org or not org.is_real():
Organization.default().change_to()
nodes = list(cls.objects.all())
if count > 100:
length = 100
else:
length = count
for i in range(length):
node = random.choice(nodes)
child = node.create_child('Node {}'.format(i))
print("{}. {}".format(i, child))

View File

@@ -9,7 +9,6 @@ from django.utils.translation import ugettext_lazy as _
from django.core.validators import MinValueValidator, MaxValueValidator
from common.utils import signer
from common.fields.model import JsonListCharField
from .base import BaseUser
from .asset import Asset
@@ -65,6 +64,26 @@ class AdminUser(BaseUser):
unique_together = [('name', 'org_id')]
verbose_name = _("Admin user")
@classmethod
def generate_fake(cls, count=10):
from random import seed
import forgery_py
from django.db import IntegrityError
seed()
for i in range(count):
obj = cls(name=forgery_py.name.full_name(),
username=forgery_py.internet.user_name(),
password=forgery_py.lorem_ipsum.word(),
comment=forgery_py.lorem_ipsum.sentence(),
created_by='Fake')
try:
obj.save()
logger.debug('Generate fake asset group: %s' % obj.name)
except IntegrityError:
print('Error continue')
continue
class SystemUser(BaseUser):
PROTOCOL_SSH = 'ssh'
@@ -72,14 +91,12 @@ class SystemUser(BaseUser):
PROTOCOL_TELNET = 'telnet'
PROTOCOL_VNC = 'vnc'
PROTOCOL_MYSQL = 'mysql'
PROTOCOL_K8S = 'k8s'
PROTOCOL_CHOICES = (
(PROTOCOL_SSH, 'ssh'),
(PROTOCOL_RDP, 'rdp'),
(PROTOCOL_TELNET, 'telnet'),
(PROTOCOL_VNC, 'vnc'),
(PROTOCOL_MYSQL, 'mysql'),
(PROTOCOL_K8S, 'k8s'),
)
LOGIN_AUTO = 'auto'
@@ -101,9 +118,6 @@ class SystemUser(BaseUser):
login_mode = models.CharField(choices=LOGIN_MODE_CHOICES, default=LOGIN_AUTO, max_length=10, verbose_name=_('Login mode'))
cmd_filters = models.ManyToManyField('CommandFilter', related_name='system_users', verbose_name=_("Command filter"), blank=True)
sftp_root = models.CharField(default='tmp', max_length=128, verbose_name=_("SFTP Root"))
token = models.TextField(default='', verbose_name=_('Token'))
home = models.CharField(max_length=4096, default='', verbose_name=_('Home'), blank=True)
system_groups = models.CharField(default='', max_length=4096, verbose_name=_('System groups'), blank=True)
_prefer = 'system_user'
def __str__(self):
@@ -140,11 +154,6 @@ class SystemUser(BaseUser):
def is_need_test_asset_connective(self):
return self.protocol not in [self.PROTOCOL_MYSQL]
def has_special_auth(self, asset=None, username=None):
if username is None and self.username_same_with_user:
raise TypeError('System user is dynamic, username should be pass')
return super().has_special_auth(asset=asset, username=username)
@property
def can_perm_to_asset(self):
return self.protocol not in [self.PROTOCOL_MYSQL]
@@ -184,3 +193,23 @@ class SystemUser(BaseUser):
ordering = ['name']
unique_together = [('name', 'org_id')]
verbose_name = _("System user")
@classmethod
def generate_fake(cls, count=10):
from random import seed
import forgery_py
from django.db import IntegrityError
seed()
for i in range(count):
obj = cls(name=forgery_py.name.full_name(),
username=forgery_py.internet.user_name(),
password=forgery_py.lorem_ipsum.word(),
comment=forgery_py.lorem_ipsum.sentence(),
created_by='Fake')
try:
obj.save()
logger.debug('Generate fake asset group: %s' % obj.name)
except IntegrityError:
print('Error continue')
continue

View File

@@ -1,39 +0,0 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request
from assets.models import Node
class AssetLimitOffsetPagination(LimitOffsetPagination):
"""
需要与 `assets.api.mixin.FilterAssetByNodeMixin` 配合使用
"""
def get_count(self, queryset):
"""
1. 如果查询节点下的所有资产,那 count 使用 Node.assets_amount
2. 如果有其他过滤条件使用 super
3. 如果只查询该节点下的资产使用 super
"""
exclude_query_params = {
self.limit_query_param,
self.offset_query_param,
'node', 'all', 'show_current_asset',
'node_id', 'display', 'draw', 'fields_size',
}
for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None:
return super().get_count(queryset)
is_query_all = self._view.is_query_node_all_assets
if is_query_all:
node = self._view.node
if not node:
node = Node.org_root()
return node.assets_amount
return super().get_count(queryset)
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request
self._view = view
return super().paginate_queryset(queryset, request, view=None)

View File

@@ -67,14 +67,12 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
slug_field='name', queryset=Platform.objects.all(), label=_("Platform")
)
protocols = ProtocolsField(label=_('Protocols'), required=False)
domain_display = serializers.ReadOnlyField(source='domain.name')
admin_user_display = serializers.ReadOnlyField(source='admin_user.name')
"""
资产的数据结构
"""
class Meta:
model = Asset
list_serializer_class = AdaptedBulkListSerializer
fields_mini = ['id', 'hostname', 'ip']
fields_small = fields_mini + [
'protocol', 'port', 'protocols', 'is_active', 'public_ip',
@@ -84,7 +82,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
'created_by', 'date_created', 'hardware_info',
]
fields_fk = [
'admin_user', 'admin_user_display', 'domain', 'domain_display', 'platform'
'admin_user', 'admin_user_display', 'domain', 'platform'
]
fk_only_fields = {
'platform': ['name']
@@ -115,7 +113,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.select_related('admin_user', 'domain', 'platform')
queryset = queryset.prefetch_related('nodes', 'labels')
return queryset
def compatible_with_old_protocol(self, validated_data):
@@ -153,7 +150,7 @@ class AssetDisplaySerializer(AssetSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
queryset = super().setup_eager_loading(queryset)
""" Perform necessary eager loading of data. """
queryset = queryset\
.annotate(admin_user_username=F('admin_user__username'))
return queryset

View File

@@ -33,15 +33,13 @@ class SystemUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
'login_mode', 'login_mode_display',
'priority', 'username_same_with_user',
'auto_push', 'cmd_filters', 'sudo', 'shell', 'comment',
'auto_generate_key', 'sftp_root', 'token',
'assets_amount', 'date_created', 'created_by',
'home', 'system_groups'
'auto_generate_key', 'sftp_root',
'assets_amount', 'date_created', 'created_by'
]
extra_kwargs = {
'password': {"write_only": True},
'public_key': {"write_only": True},
'private_key': {"write_only": True},
'token': {"write_only": True},
'nodes_amount': {'label': _('Node')},
'assets_amount': {'label': _('Asset')},
'login_mode_display': {'label': _('Login mode display')},
@@ -145,25 +143,17 @@ class SystemUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
class SystemUserListSerializer(SystemUserSerializer):
class Meta(SystemUserSerializer.Meta):
fields = [
'id', 'name', 'username', 'protocol',
'password', 'public_key', 'private_key',
'login_mode', 'login_mode_display',
'priority', "username_same_with_user",
'auto_push', 'sudo', 'shell', 'comment',
"assets_amount", 'home', 'system_groups',
"assets_amount",
'auto_generate_key',
'sftp_root',
]
extra_kwargs = {
'password': {"write_only": True},
'public_key': {"write_only": True},
'private_key': {"write_only": True},
}
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
@@ -179,7 +169,7 @@ class SystemUserWithAuthInfoSerializer(SystemUserSerializer):
'login_mode', 'login_mode_display',
'priority', 'username_same_with_user',
'auto_push', 'sudo', 'shell', 'comment',
'auto_generate_key', 'sftp_root', 'token'
'auto_generate_key', 'sftp_root',
]
extra_kwargs = {
'nodes_amount': {'label': _('Node')},
@@ -229,8 +219,15 @@ class SystemUserNodeRelationSerializer(RelationMixin, serializers.ModelSerialize
'id', 'node', "node_display",
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tree = Node.tree()
def get_node_display(self, obj):
return obj.node.full_value
if hasattr(obj, 'node_key'):
return self.tree.get_node_full_tag(obj.node_key)
else:
return obj.node.full_value
class SystemUserUserRelationSerializer(RelationMixin, serializers.ModelSerializer):

View File

@@ -1,20 +1,17 @@
# -*- coding: utf-8 -*-
#
from operator import add, sub
from assets.utils import is_asset_exists_in_node
from collections import defaultdict
from django.db.models.signals import (
post_save, m2m_changed, pre_delete, post_delete
post_save, m2m_changed, post_delete
)
from django.db.models import Q, F
from django.db.models.aggregates import Count
from django.dispatch import receiver
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import PRE_ADD, POST_ADD, POST_REMOVE, PRE_CLEAR, PRE_REMOVE
from common.utils import get_logger
from common.decorator import on_transaction_commit
from .models import Asset, SystemUser, Node, compute_parent_key
from users.models import User
from orgs.utils import tmp_to_root_org
from .models import Asset, SystemUser, Node, AuthBook
from .utils import TreeService
from .tasks import (
update_assets_hardware_info_util,
test_asset_connectivity_util,
@@ -57,6 +54,15 @@ def on_asset_created_or_update(sender, instance=None, created=False, **kwargs):
instance.nodes.add(Node.org_root())
@receiver(post_delete, sender=Asset)
def on_asset_delete(sender, instance=None, **kwargs):
"""
当资产删除时,刷新节点,节点中存在节点和资产的关系
"""
logger.debug("Asset delete signal recv: {}".format(instance))
Node.refresh_assets()
@receiver(post_save, sender=SystemUser, dispatch_uid="jms")
def on_system_user_update(sender, instance=None, created=True, **kwargs):
"""
@@ -76,7 +82,7 @@ def on_system_user_assets_change(sender, instance=None, action='', model=None, p
"""
当系统用户和资产关系发生变化时,应该重新推送系统用户到新添加的资产中
"""
if action != POST_ADD:
if action != "post_add":
return
logger.debug("System user assets change signal recv: {}".format(instance))
queryset = model.objects.filter(pk__in=pk_set)
@@ -95,7 +101,7 @@ def on_system_user_users_change(sender, instance=None, action='', model=None, pk
"""
当系统用户和用户关系发生变化时,应该重新推送系统用户资产中
"""
if action != POST_ADD:
if action != "post_add":
return
if not instance.username_same_with_user:
return
@@ -114,7 +120,7 @@ def on_system_user_nodes_change(sender, instance=None, action=None, model=None,
"""
当系统用户和节点关系发生变化时,应该将节点下资产关联到新的系统用户上
"""
if action != POST_ADD:
if action != "post_add":
return
logger.info("System user nodes update signal recv: {}".format(instance))
@@ -129,217 +135,102 @@ def on_system_user_nodes_change(sender, instance=None, action=None, model=None,
@receiver(m2m_changed, sender=SystemUser.groups.through)
def on_system_user_groups_change(instance, action, pk_set, reverse, **kwargs):
def on_system_user_groups_change(sender, instance=None, action=None, model=None,
pk_set=None, reverse=False, **kwargs):
"""
当系统用户和用户组关系发生变化时,应该将组下用户关联到新的系统用户上
"""
if action != POST_ADD:
if action != "post_add" or reverse:
return
if reverse:
raise M2MReverseNotAllowed
logger.info("System user groups update signal recv: {}".format(instance))
users = User.objects.filter(groups__id__in=pk_set).distinct()
instance.users.add(*users)
groups = model.objects.filter(pk__in=pk_set).annotate(users_count=Count("users"))
users = groups.filter(users_count__gt=0).values_list('users', flat=True)
instance.users.add(*tuple(users))
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs):
def on_asset_nodes_change(sender, instance=None, action='', **kwargs):
"""
本操作共访问 4 次数据库
资产节点发生变化时,刷新节点
"""
if action.startswith('post'):
logger.debug("Asset nodes change signal recv: {}".format(instance))
Node.refresh_assets()
with tmp_to_root_org():
Node.refresh_assets()
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_asset_nodes_add(sender, instance=None, action='', model=None, pk_set=None, **kwargs):
"""
当资产的节点发生变化时,或者 当节点的资产关系发生变化时,
节点下新增的资产,添加到节点关联的系统用户中
"""
if action != POST_ADD:
if action != "post_add":
return
logger.debug("Assets node add signal recv: {}".format(action))
if reverse:
nodes = [instance.key]
asset_ids = pk_set
if model == Node:
nodes = model.objects.filter(pk__in=pk_set).values_list('key', flat=True)
assets = [instance.id]
else:
nodes = Node.objects.filter(pk__in=pk_set).values_list('key', flat=True)
asset_ids = [instance.id]
nodes = [instance.key]
assets = model.objects.filter(pk__in=pk_set).values_list('id', flat=True)
# 节点资产发生变化时,将资产关联到节点及祖先节点关联的系统用户, 只关注新增的
nodes_ancestors_keys = set()
node_tree = TreeService.new()
for node in nodes:
nodes_ancestors_keys.update(Node.get_node_ancestor_keys(node, with_self=True))
ancestors_keys = node_tree.ancestors_ids(nid=node)
nodes_ancestors_keys.update(ancestors_keys)
system_users = SystemUser.objects.filter(nodes__key__in=nodes_ancestors_keys)
# 查询所有祖先节点关联的系统用户,都是要跟资产建立关系的
system_user_ids = SystemUser.objects.filter(
nodes__key__in=nodes_ancestors_keys
).distinct().values_list('id', flat=True)
# 查询所有已存在的关系
m2m_model = SystemUser.assets.through
exist = set(m2m_model.objects.filter(
systemuser_id__in=system_user_ids, asset_id__in=asset_ids
).values_list('systemuser_id', 'asset_id'))
# TODO 优化
to_create = []
for system_user_id in system_user_ids:
asset_ids_to_push = []
for asset_id in asset_ids:
if (system_user_id, asset_id) in exist:
continue
asset_ids_to_push.append(asset_id)
to_create.append(m2m_model(
systemuser_id=system_user_id,
asset_id=asset_id
))
push_system_user_to_assets.delay(system_user_id, asset_ids_to_push)
m2m_model.objects.bulk_create(to_create)
def _update_node_assets_amount(node: Node, asset_pk_set: set, operator=add):
"""
一个节点与多个资产关系变化时,更新计数
:param node: 节点实例
:param asset_pk_set: 资产的`id`集合, 内部不会修改该值
:param operator: 操作
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key`
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3]
# 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的
# 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量
# 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用
# 处理了
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))
def _remove_ancestor_keys(ancestor_key, tree_set):
# 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
def _update_nodes_asset_amount(node_keys, asset_pk, operator):
"""
一个资产与多个节点关系变化时,更新计数
:param node_keys: 节点 id 的集合
:param asset_pk: 资产 id
:param operator: 操作
"""
# 所有相关节点的祖先节点,组成一棵局部树
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# 遍历相关节点,处理它及其祖先节点
# 查询该节点是否包含待处理资产
exists = is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# 如果资产在该节点,那么他及其祖先节点都不用处理
_remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# 不存在,要更新本节点
to_update_keys.append(key)
# 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while parent_key and parent_key in ancestor_keys:
exists = is_asset_exists_in_node(asset_pk, parent_key)
if exists:
_remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
system_users_assets = defaultdict(set)
for system_user in system_users:
system_users_assets[system_user].update(set(assets))
for system_user, _assets in system_users_assets.items():
system_user.assets.add(*tuple(_assets))
@receiver(m2m_changed, sender=Asset.nodes.through)
def update_nodes_assets_amount(action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
refused = (PRE_CLEAR,)
if action in refused:
raise ValueError
def on_asset_nodes_remove(sender, instance=None, action='', model=None,
pk_set=None, **kwargs):
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
if action not in mapper:
"""
监控资产删除节点关系, 或节点删除资产,避免产生游离资产
"""
if action not in ["post_remove", "pre_clear", "post_clear"]:
return
operator = mapper[action]
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
_update_node_assets_amount(node, asset_pk_set, operator)
if action == "pre_clear":
if model == Node:
instance._nodes = list(instance.nodes.all())
else:
instance._assets = list(instance.assets.all())
return
logger.debug("Assets node remove signal recv: {}".format(action))
if action == "post_remove":
queryset = model.objects.filter(pk__in=pk_set)
else:
asset_pk = instance.id
# 与资产直接关联的节点
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
_update_nodes_asset_amount(node_keys, asset_pk, operator)
if model == Node:
queryset = instance._nodes
else:
queryset = instance._assets
if model == Node:
assets = [instance]
else:
assets = queryset
if isinstance(assets, list):
assets_not_has_node = []
for asset in assets:
if asset.nodes.all().count() == 0:
assets_not_has_node.append(asset.id)
else:
assets_not_has_node = assets.annotate(nodes_count=Count('nodes'))\
.filter(nodes_count=0).values_list('id', flat=True)
Node.org_root().assets.add(*tuple(assets_not_has_node))
RELATED_NODE_IDS = '_related_node_ids'
@receiver(pre_delete, sender=Asset)
def on_asset_delete(instance: Asset, using, **kwargs):
node_ids = set(Node.objects.filter(
assets=instance
).distinct().values_list('id', flat=True))
setattr(instance, RELATED_NODE_IDS, node_ids)
m2m_changed.send(
sender=Asset.nodes.through, instance=instance, reverse=False,
model=Node, pk_set=node_ids, using=using, action=PRE_REMOVE
)
@receiver(post_delete, sender=Asset)
def on_asset_post_delete(instance: Asset, using, **kwargs):
node_ids = getattr(instance, RELATED_NODE_IDS, None)
if node_ids:
m2m_changed.send(
sender=Asset.nodes.through, instance=instance, reverse=False,
model=Node, pk_set=node_ids, using=using, action=POST_REMOVE
)
@receiver([post_save, post_delete], sender=Node)
def on_node_update_or_created(sender, **kwargs):
# 刷新节点
Node.refresh_nodes()
with tmp_to_root_org():
Node.refresh_nodes()

View File

@@ -9,4 +9,3 @@ from .gather_asset_users import *
from .gather_asset_hardware_info import *
from .push_system_user import *
from .system_user_connectivity import *
from .nodes_amount import *

View File

@@ -85,7 +85,7 @@ def test_asset_user_connectivity_util(asset_user, task_name):
raw, summary = test_user_connectivity(
task_name=task_name, asset=asset_user.asset,
username=asset_user.username, password=asset_user.password,
private_key=asset_user.private_key_file
private_key=asset_user.private_key
)
except Exception as e:
logger.warn("Failed run adhoc {}, {}".format(task_name, e))

View File

@@ -64,7 +64,7 @@ GATHER_ASSET_USERS_TASKS = [
"action": {
"module": "shell",
"args": "users=$(getent passwd | grep -v 'nologin' | "
"grep -v 'shudown' | awk -F: '{ print $1 }');for i in $users;do last -w -F $i -1 | "
"grep -v 'shudown' | awk -F: '{ print $1 }');for i in $users;do last -F $i -1 | "
"head -1 | grep -v '^$' | awk '{ print $1\"@\"$3\"@\"$5,$6,$7,$8 }';done"
}
}

View File

@@ -129,5 +129,5 @@ def update_assets_hardware_info_period():
def update_node_assets_hardware_info_manual(node):
task_name = _("Update node asset hardware information: {}").format(node.name)
assets = node.get_all_assets()
result = update_assets_hardware_info_util(assets, task_name=task_name)
result = update_assets_hardware_info_util.delay(assets, task_name=task_name)
return result

View File

@@ -92,7 +92,7 @@ def add_asset_users(assets, results):
for username, data in users.items():
defaults = {'asset': asset, 'username': username, 'present': True}
if data.get("ip"):
defaults["ip_last_login"] = data["ip"][:32]
defaults["ip_last_login"] = data["ip"]
if data.get("date"):
defaults["date_last_login"] = data["date"]
GatheredUser.objects.update_or_create(

View File

@@ -1,14 +0,0 @@
from celery import shared_task
from assets.utils import check_node_assets_amount
from common.utils import get_logger
from common.utils.timezone import now
logger = get_logger(__file__)
@shared_task()
def check_node_assets_amount_celery_task():
logger.info(f'>>> {now()} begin check_node_assets_amount_celery_task ...')
check_node_assets_amount()
logger.info(f'>>> {now()} end check_node_assets_amount_celery_task ...')

View File

@@ -2,13 +2,10 @@
from itertools import groupby
from celery import shared_task
from common.db.utils import get_object_if_need, get_objects_if_need
from django.utils.translation import ugettext as _
from django.db.models import Empty
from common.utils import encrypt_password, get_logger
from assets.models import SystemUser, Asset
from orgs.utils import org_aware_func
from orgs.utils import tmp_to_org, org_aware_func
from . import const
from .utils import clean_ansible_task_hosts, group_asset_by_platform
@@ -20,42 +17,20 @@ __all__ = [
]
def _split_by_comma(raw: str):
try:
return [i.strip() for i in raw.split(',')]
except AttributeError:
return []
def _dump_args(args: dict):
return ' '.join([f'{k}={v}' for k, v in args.items() if v is not Empty])
def get_push_unixlike_system_user_tasks(system_user, username=None):
if username is None:
username = system_user.username
password = system_user.password
public_key = system_user.public_key
groups = _split_by_comma(system_user.system_groups)
if groups:
groups = '"%s"' % ','.join(groups)
add_user_args = {
'name': username,
'shell': system_user.shell or Empty,
'state': 'present',
'home': system_user.home or Empty,
'groups': groups or Empty
}
tasks = [
{
'name': 'Add user {}'.format(username),
'action': {
'module': 'user',
'args': _dump_args(add_user_args),
'args': 'name={} shell={} state=present'.format(
username, system_user.shell or '/bin/bash',
),
}
},
{
@@ -127,14 +102,8 @@ def get_push_windows_system_user_tasks(system_user, username=None):
if username is None:
username = system_user.username
password = system_user.password
groups = {'Users', 'Remote Desktop Users'}
if system_user.system_groups:
groups.update(_split_by_comma(system_user.system_groups))
groups = ','.join(groups)
tasks = []
if not password:
logger.error("Error: no password found")
return tasks
task = {
'name': 'Add user {}'.format(username),
@@ -147,9 +116,9 @@ def get_push_windows_system_user_tasks(system_user, username=None):
'update_password=always '
'password_expired=no '
'password_never_expires=yes '
'groups="{}" '
'groups="Users,Remote Desktop Users" '
'groups_action=add '
''.format(username, username, password, groups),
''.format(username, username, password),
}
}
tasks.append(task)
@@ -210,22 +179,20 @@ def push_system_user_util(system_user, assets, task_name, username=None):
print(_("Start push system user for platform: [{}]").format(platform))
print(_("Hosts count: {}").format(len(_hosts)))
# 如果没有特殊密码设置,就不需要单独推送某台机器了
if not system_user.has_special_auth(username=username):
if not system_user.has_special_auth():
logger.debug("System user not has special auth")
tasks = get_push_system_user_tasks(system_user, platform, username=username)
run_task(tasks, _hosts)
continue
for _host in _hosts:
system_user.load_asset_special_auth(_host, username=username)
system_user.load_asset_special_auth(_host)
tasks = get_push_system_user_tasks(system_user, platform, username=username)
run_task(tasks, [_host])
@shared_task(queue="ansible")
def push_system_user_to_assets_manual(system_user, username=None):
system_user = get_object_if_need(SystemUser, system_user)
assets = system_user.get_related_assets()
task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name=task_name, username=username)
@@ -243,11 +210,11 @@ def push_system_user_a_asset_manual(system_user, asset, username=None):
@shared_task(queue="ansible")
def push_system_user_to_assets(system_user, assets, username=None):
system_user = get_object_if_need(SystemUser, system_user)
task_name = _("Push system users to assets: {}").format(system_user.name)
assets = get_objects_if_need(Asset, assets)
return push_system_user_util(system_user, assets, task_name, username=username)
# @shared_task
# @register_as_period_task(interval=3600)
# @after_app_ready_start

View File

@@ -31,10 +31,7 @@ def test_system_user_connectivity_util(system_user, assets, task_name):
"""
from ops.utils import update_or_create_ansible_task
# hosts = clean_ansible_task_hosts(assets, system_user=system_user)
# TODO: 这里不传递系统用户因为clean_ansible_task_hosts会通过system_user来判断是否可以推送
# 不符合测试可连接性逻辑, 后面需要优化此逻辑
hosts = clean_ansible_task_hosts(assets)
hosts = clean_ansible_task_hosts(assets, system_user=system_user)
if not hosts:
return {}
platform_hosts_map = {}

View File

@@ -1,8 +1,8 @@
# coding:utf-8
from django.urls import path, re_path
from rest_framework_nested import routers
# from rest_framework.routers import DefaultRouter
from rest_framework_bulk.routes import BulkRouter
from django.db.transaction import non_atomic_requests
from common import api as capi
@@ -55,9 +55,9 @@ urlpatterns = [
path('nodes/children/', api.NodeChildrenApi.as_view(), name='node-children-2'),
path('nodes/<uuid:pk>/children/add/', api.NodeAddChildrenApi.as_view(), name='node-add-children'),
path('nodes/<uuid:pk>/assets/', api.NodeAssetsApi.as_view(), name='node-assets'),
path('nodes/<uuid:pk>/assets/add/', non_atomic_requests(api.NodeAddAssetsApi.as_view()), name='node-add-assets'),
path('nodes/<uuid:pk>/assets/replace/', non_atomic_requests(api.MoveAssetsToNodeApi.as_view()), name='node-replace-assets'),
path('nodes/<uuid:pk>/assets/remove/', non_atomic_requests(api.NodeRemoveAssetsApi.as_view()), name='node-remove-assets'),
path('nodes/<uuid:pk>/assets/add/', api.NodeAddAssetsApi.as_view(), name='node-add-assets'),
path('nodes/<uuid:pk>/assets/replace/', api.NodeReplaceAssetsApi.as_view(), name='node-replace-assets'),
path('nodes/<uuid:pk>/assets/remove/', api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'),
path('nodes/<uuid:pk>/tasks/', api.NodeTaskCreateApi.as_view(), name='node-task-create'),
path('gateways/<uuid:pk>/test-connective/', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'),

View File

@@ -1,52 +1,195 @@
# ~*~ coding: utf-8 ~*~
#
from django.db.models import Q
from treelib import Tree
from treelib.exceptions import NodeIDAbsentError
from collections import defaultdict
from copy import deepcopy
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none
from common.http import is_true
from common.utils import get_logger, timeit, lazyproperty
from .models import Asset, Node
logger = get_logger(__file__)
def check_node_assets_amount():
for node in Node.objects.all():
assets_amount = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
class TreeService(Tree):
tag_sep = ' / '
if node.assets_amount != assets_amount:
print(f'>>> <Node:{node.key}> wrong assets amount '
f'{node.assets_amount} right is {assets_amount}')
node.assets_amount = assets_amount
node.save()
@staticmethod
@timeit
def get_nodes_assets_map():
nodes_assets_map = defaultdict(set)
asset_node_list = Node.assets.through.objects.values_list(
'asset', 'node__key'
)
for asset_id, key in asset_node_list:
nodes_assets_map[key].add(asset_id)
return nodes_assets_map
@classmethod
@timeit
def new(cls):
from .models import Node
all_nodes = list(Node.objects.all().values("key", "value"))
all_nodes.sort(key=lambda x: len(x["key"].split(":")))
tree = cls()
tree.create_node(tag='', identifier='', data={})
for node in all_nodes:
key = node["key"]
value = node["value"]
parent_key = ":".join(key.split(":")[:-1])
tree.safe_create_node(
tag=value, identifier=key,
parent=parent_key,
)
tree.init_assets()
return tree
def is_asset_exists_in_node(asset_pk, node_key):
return Asset.objects.filter(
id=asset_pk
).filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).exists()
def init_assets(self):
node_assets_map = self.get_nodes_assets_map()
for node in self.all_nodes_itr():
key = node.identifier
assets = node_assets_map.get(key, set())
data = {"assets": assets, "all_assets": None}
node.data = data
def safe_create_node(self, **kwargs):
parent = kwargs.get("parent")
if not self.contains(parent):
kwargs['parent'] = self.root
self.create_node(**kwargs)
def is_query_node_all_assets(request):
request = request
query_all_arg = request.query_params.get('all')
show_current_asset_arg = request.query_params.get('show_current_asset')
if show_current_asset_arg is not None:
return not is_true(show_current_asset_arg)
return is_true(query_all_arg)
def all_children_ids(self, nid, with_self=True):
children_ids = self.expand_tree(nid)
if not with_self:
next(children_ids)
return list(children_ids)
def all_children(self, nid, with_self=True, deep=False):
children_ids = self.all_children_ids(nid, with_self=with_self)
return [self.get_node(i, deep=deep) for i in children_ids]
def get_node(request):
node_id = dict_get_any(request.query_params, ['node', 'node_id'])
if not node_id:
return None
def ancestors_ids(self, nid, with_self=True):
ancestor_ids = list(self.rsearch(nid))
ancestor_ids.pop()
if not with_self:
ancestor_ids.pop(0)
return ancestor_ids
if is_uuid(node_id):
node = get_object_or_none(Node, id=node_id)
else:
node = get_object_or_none(Node, key=node_id)
return node
def ancestors(self, nid, with_self=False, deep=False):
ancestor_ids = self.ancestors_ids(nid, with_self=with_self)
ancestors = [self.get_node(i, deep=deep) for i in ancestor_ids]
return ancestors
def get_node_full_tag(self, nid):
ancestors = self.ancestors(nid, with_self=True)
ancestors.reverse()
return self.tag_sep.join([n.tag for n in ancestors])
def get_family(self, nid, deep=False):
ancestors = self.ancestors(nid, with_self=False, deep=deep)
children = self.all_children(nid, with_self=False)
return ancestors + [self[nid]] + children
@staticmethod
def is_parent(child, parent):
parent_id = child.bpointer
return parent_id == parent.identifier
def root_node(self):
return self.get_node(self.root)
def get_node(self, nid, deep=False):
node = super().get_node(nid)
if deep:
node = self.copy_node(node)
node.data = {}
return node
def parent(self, nid, deep=False):
parent = super().parent(nid)
if deep:
parent = self.copy_node(parent)
return parent
@lazyproperty
def invalid_assets(self):
assets = Asset.objects.filter(is_active=False).values_list('id', flat=True)
return assets
def set_assets(self, nid, assets):
node = self.get_node(nid)
if node.data is None:
node.data = {}
node.data["assets"] = assets
def assets(self, nid):
node = self.get_node(nid)
return node.data.get("assets", set())
def valid_assets(self, nid):
return set(self.assets(nid)) - set(self.invalid_assets)
def all_assets(self, nid):
node = self.get_node(nid)
if node.data is None:
node.data = {}
all_assets = node.data.get("all_assets")
if all_assets is not None:
return all_assets
all_assets = set(self.assets(nid))
try:
children = self.children(nid)
except NodeIDAbsentError:
children = []
for child in children:
all_assets.update(self.all_assets(child.identifier))
node.data["all_assets"] = all_assets
return all_assets
def all_valid_assets(self, nid):
return set(self.all_assets(nid)) - set(self.invalid_assets)
def assets_amount(self, nid):
return len(self.all_assets(nid))
def valid_assets_amount(self, nid):
return len(self.all_valid_assets(nid))
@staticmethod
def copy_node(node):
new_node = deepcopy(node)
new_node.fpointer = None
return new_node
def safe_add_ancestors(self, node, ancestors):
# 如果没有祖先节点,那么添加该节点, 父节点是root node
if len(ancestors) == 0:
parent = self.root_node()
else:
parent = ancestors[0]
# 如果当前节点已再树中,则移动当前节点到父节点中
# 这个是由于 当前节点放到了二级节点中
if not self.contains(parent.identifier):
# logger.debug('Add parent: {}'.format(parent.identifier))
self.safe_add_ancestors(parent, ancestors[1:])
if self.contains(node.identifier):
# msg = 'Move node to parent: {} => {}'.format(
# node.identifier, parent.identifier
# )
# logger.debug(msg)
self.move_node(node.identifier, parent.identifier)
else:
# logger.debug('Add node: {}'.format(node.identifier))
self.add_node(node, parent)
#
# def __getstate__(self):
# self.mutex = None
# self.all_nodes_assets_map = {}
# self.nodes_assets_map = {}
# return self.__dict__
# def __setstate__(self, state):
# self.__dict__ = state

View File

@@ -43,7 +43,7 @@ class UserLoginLogViewSet(ListModelMixin, CommonGenericViewSet):
@staticmethod
def get_org_members():
users = current_org.get_members().values_list('username', flat=True)
users = current_org.get_org_members().values_list('username', flat=True)
return users
def get_queryset(self):
@@ -79,7 +79,7 @@ class PasswordChangeLogViewSet(ListModelMixin, CommonGenericViewSet):
ordering = ['-datetime']
def get_queryset(self):
users = current_org.get_members()
users = current_org.get_org_members()
queryset = super().get_queryset().filter(
user__in=[user.__str__() for user in users]
)
@@ -107,7 +107,7 @@ class CommandExecutionViewSet(ListModelMixin, OrgGenericViewSet):
class CommandExecutionHostRelationViewSet(OrgRelationMixin, OrgBulkModelViewSet):
serializer_class = CommandExecutionHostsRelationSerializer
m2m_field = CommandExecution.hosts.field
permission_classes = [IsOrgAdmin | IsOrgAuditor]
permission_classes = (IsOrgAdmin,)
filter_fields = [
'id', 'asset', 'commandexecution'
]

View File

@@ -20,7 +20,7 @@ class CurrentOrgMembersFilter(filters.BaseFilterBackend):
]
def _get_user_list(self):
users = current_org.get_members(exclude=('Auditor',))
users = current_org.get_org_members(exclude=('Auditor',))
return users
def filter_queryset(self, request, queryset, view):

View File

@@ -1,18 +0,0 @@
# Generated by Django 2.2.10 on 2020-06-24 08:54
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('audits', '0008_auto_20200508_2105'),
]
operations = [
migrations.AlterField(
model_name='ftplog',
name='operate',
field=models.CharField(choices=[('Delete', 'Delete'), ('Upload', 'Upload'), ('Download', 'Download'), ('Rmdir', 'Rmdir'), ('Rename', 'Rename'), ('Mkdir', 'Mkdir'), ('Symlink', 'Symlink')], max_length=16, verbose_name='Operate'),
),
]

View File

@@ -1,18 +0,0 @@
# Generated by Django 2.2.13 on 2020-08-11 03:22
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('audits', '0009_auto_20200624_1654'),
]
operations = [
migrations.AlterField(
model_name='operatelog',
name='datetime',
field=models.DateTimeField(auto_now=True, db_index=True, verbose_name='Datetime'),
),
]

View File

@@ -14,30 +14,12 @@ __all__ = [
class FTPLog(OrgModelMixin):
OPERATE_DELETE = 'Delete'
OPERATE_UPLOAD = 'Upload'
OPERATE_DOWNLOAD = 'Download'
OPERATE_RMDIR = 'Rmdir'
OPERATE_RENAME = 'Rename'
OPERATE_MKDIR = 'Mkdir'
OPERATE_SYMLINK = 'Symlink'
OPERATE_CHOICES = (
(OPERATE_DELETE, _('Delete')),
(OPERATE_UPLOAD, _('Upload')),
(OPERATE_DOWNLOAD, _('Download')),
(OPERATE_RMDIR, _('Rmdir')),
(OPERATE_RENAME, _('Rename')),
(OPERATE_MKDIR, _('Mkdir')),
(OPERATE_SYMLINK, _('Symlink'))
)
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
user = models.CharField(max_length=128, verbose_name=_('User'))
remote_addr = models.CharField(max_length=128, verbose_name=_("Remote addr"), blank=True, null=True)
asset = models.CharField(max_length=1024, verbose_name=_("Asset"))
system_user = models.CharField(max_length=128, verbose_name=_("System user"))
operate = models.CharField(max_length=16, verbose_name=_("Operate"), choices=OPERATE_CHOICES)
operate = models.CharField(max_length=16, verbose_name=_("Operate"))
filename = models.CharField(max_length=1024, verbose_name=_("Filename"))
is_success = models.BooleanField(default=True, verbose_name=_("Success"))
date_start = models.DateTimeField(auto_now_add=True, verbose_name=_('Date start'))
@@ -58,7 +40,7 @@ class OperateLog(OrgModelMixin):
resource_type = models.CharField(max_length=64, verbose_name=_("Resource Type"))
resource = models.CharField(max_length=128, verbose_name=_("Resource"))
remote_addr = models.CharField(max_length=128, verbose_name=_("Remote addr"), blank=True, null=True)
datetime = models.DateTimeField(auto_now=True, verbose_name=_('Datetime'), db_index=True)
datetime = models.DateTimeField(auto_now=True, verbose_name=_('Datetime'))
def __str__(self):
return "<{}> {} <{}>".format(self.user, self.action, self.resource)
@@ -124,7 +106,7 @@ class UserLoginLog(models.Model):
Q(username__contains=keyword)
)
if not current_org.is_root():
username_list = current_org.get_members().values_list('username', flat=True)
username_list = current_org.get_org_members().values_list('username', flat=True)
login_logs = login_logs.filter(username__in=username_list)
return login_logs

View File

@@ -12,13 +12,12 @@ from . import models
class FTPLogSerializer(serializers.ModelSerializer):
operate_display = serializers.ReadOnlyField(source='get_operate_display')
class Meta:
model = models.FTPLog
fields = (
'id', 'user', 'remote_addr', 'asset', 'system_user', 'org_id',
'operate', 'filename', 'is_success', 'date_start', 'operate_display'
'id', 'user', 'remote_addr', 'asset', 'system_user',
'operate', 'filename', 'is_success', 'date_start'
)
@@ -40,7 +39,7 @@ class OperateLogSerializer(serializers.ModelSerializer):
model = models.OperateLog
fields = (
'id', 'user', 'action', 'resource_type', 'resource',
'remote_addr', 'datetime', 'org_id'
'remote_addr', 'datetime'
)
@@ -66,7 +65,7 @@ class CommandExecutionSerializer(serializers.ModelSerializer):
fields_mini = ['id']
fields_small = fields_mini + [
'run_as', 'command', 'user', 'is_finished',
'date_start', 'result', 'is_success', 'org_id'
'date_start', 'result', 'is_success'
]
fields = fields_small + ['hosts', 'run_as_display', 'user_display']
extra_kwargs = {

View File

@@ -16,7 +16,7 @@ def clean_login_log_period():
try:
days = int(settings.LOGIN_LOG_KEEP_DAYS)
except ValueError:
days = 9999
days = 90
expired_day = now - datetime.timedelta(days=days)
UserLoginLog.objects.filter(datetime__lt=expired_day).delete()
@@ -28,6 +28,6 @@ def clean_operation_log_period():
try:
days = int(settings.LOGIN_LOG_KEEP_DAYS)
except ValueError:
days = 9999
days = 90
expired_day = now - datetime.timedelta(days=days)
OperateLog.objects.filter(datetime__lt=expired_day).delete()

View File

@@ -6,4 +6,3 @@ from .token import *
from .mfa import *
from .access_key import *
from .login_confirm import *
from .sso import *

View File

@@ -54,3 +54,12 @@ class UserConnectionTokenApi(RootOrgViewMixin, APIView):
return Response(value)
else:
return Response({'user': value['user']})
def get_permissions(self):
if self.request.query_params.get('user-only', None):
self.permission_classes = (AllowAny,)
return super().get_permissions()

View File

@@ -34,6 +34,16 @@ class LoginConfirmSettingUpdateApi(UpdateAPIView):
class TicketStatusApi(mixins.AuthMixin, APIView):
permission_classes = ()
def get_ticket(self):
from tickets.models import Ticket
ticket_id = self.request.session.get("auth_ticket_id")
logger.debug('Login confirm ticket id: {}'.format(ticket_id))
if not ticket_id:
ticket = None
else:
ticket = get_object_or_none(Ticket, pk=ticket_id)
return ticket
def get(self, request, *args, **kwargs):
try:
self.check_user_login_confirm()

View File

@@ -1,86 +0,0 @@
from uuid import UUID
from urllib.parse import urlencode
from django.contrib.auth import login
from django.conf import settings
from django.http.response import HttpResponseRedirect
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.request import Request
from common.utils.timezone import utcnow
from common.const.http import POST, GET
from common.drf.api import JmsGenericViewSet
from common.drf.serializers import EmptySerializer
from common.permissions import IsSuperUser
from common.utils import reverse
from users.models import User
from ..serializers import SSOTokenSerializer
from ..models import SSOToken
from ..filters import AuthKeyQueryDeclaration
from ..mixins import AuthMixin
from ..errors import SSOAuthClosed
NEXT_URL = 'next'
AUTH_KEY = 'authkey'
class SSOViewSet(AuthMixin, JmsGenericViewSet):
queryset = SSOToken.objects.all()
serializer_classes = {
'login_url': SSOTokenSerializer,
'login': EmptySerializer
}
@action(methods=[POST], detail=False, permission_classes=[IsSuperUser], url_path='login-url')
def login_url(self, request, *args, **kwargs):
if not settings.AUTH_SSO:
raise SSOAuthClosed()
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
username = serializer.validated_data['username']
user = User.objects.get(username=username)
next_url = serializer.validated_data.get(NEXT_URL)
operator = request.user.username
# TODO `created_by` 和 `created_by` 可以通过 `ThreadLocal` 统一处理
token = SSOToken.objects.create(user=user, created_by=operator, updated_by=operator)
query = {
AUTH_KEY: token.authkey,
NEXT_URL: next_url or ''
}
login_url = '%s?%s' % (reverse('api-auth:sso-login', external=True), urlencode(query))
return Response(data={'login_url': login_url})
@action(methods=[GET], detail=False, filter_backends=[AuthKeyQueryDeclaration], permission_classes=[])
def login(self, request: Request, *args, **kwargs):
"""
此接口违反了 `Restful` 的规范
`GET` 应该是安全的方法,但此接口是不安全的
"""
authkey = request.query_params.get(AUTH_KEY)
next_url = request.query_params.get(NEXT_URL)
if not next_url or not next_url.startswith('/'):
next_url = reverse('index')
try:
authkey = UUID(authkey)
token = SSOToken.objects.get(authkey=authkey, expired=False)
# 先过期,只能访问这一次
token.expired = True
token.save()
except (ValueError, SSOToken.DoesNotExist):
self.send_auth_signal(success=False, reason='authkey_invalid')
return HttpResponseRedirect(next_url)
# 判断是否过期
if (utcnow().timestamp() - token.date_created.timestamp()) > settings.AUTH_SSO_AUTHKEY_TTL:
self.send_auth_signal(success=False, reason='authkey_timeout')
return HttpResponseRedirect(next_url)
user = token.user
login(self.request, user, 'authentication.backends.api.SSOAuthentication')
self.send_auth_signal(success=True, user=user)
return HttpResponseRedirect(next_url)

View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
from django.shortcuts import redirect
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.generics import CreateAPIView
@@ -40,5 +40,3 @@ class TokenCreateApi(AuthMixin, CreateAPIView):
return Response(e.as_data(), status=400)
except errors.NeedMoreInfoError as e:
return Response(e.as_data(), status=200)
except errors.PasswdTooSimple as e:
return redirect(e.url)

View File

@@ -5,13 +5,14 @@ import uuid
import time
from django.core.cache import cache
from django.conf import settings
from django.utils.translation import ugettext as _
from django.utils.six import text_type
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import authentication, exceptions
from common.auth import signature
from rest_framework.authentication import CSRFCheck
from common.utils import get_object_or_none, make_signature, http_to_unixtime
from ..models import AccessKey, PrivateToken
@@ -196,12 +197,3 @@ class SignatureAuthentication(signature.SignatureAuthentication):
return user, secret
except AccessKey.DoesNotExist:
return None, None
class SSOAuthentication(ModelBackend):
"""
什么也不做呀😺
"""
def authenticate(self, request, sso_token=None, **kwargs):
pass

View File

@@ -1,10 +0,0 @@
from django_cas_ng.middleware import CASMiddleware as _CASMiddleware
from django.core.exceptions import MiddlewareNotUsed
from django.conf import settings
class CASMiddleware(_CASMiddleware):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not settings.AUTH_CAS:
raise MiddlewareNotUsed

View File

@@ -1,10 +0,0 @@
from jms_oidc_rp.middleware import OIDCRefreshIDTokenMiddleware as _OIDCRefreshIDTokenMiddleware
from django.core.exceptions import MiddlewareNotUsed
from django.conf import settings
class OIDCRefreshIDTokenMiddleware(_OIDCRefreshIDTokenMiddleware):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not settings.AUTH_OPENID:
raise MiddlewareNotUsed

View File

@@ -1,17 +1,17 @@
# -*- coding: utf-8 -*-
#
import traceback
from django.contrib.auth import get_user_model, authenticate
from django.contrib.auth import get_user_model
from radiusauth.backends import RADIUSBackend, RADIUSRealmBackend
from django.conf import settings
from pyrad.packet import AccessRequest
User = get_user_model()
class CreateUserMixin:
def get_django_user(self, username, password=None, *args, **kwargs):
def get_django_user(self, username, password=None):
if isinstance(username, bytes):
username = username.decode()
try:
@@ -27,23 +27,10 @@ class CreateUserMixin:
user.save()
return user
def _perform_radius_auth(self, client, packet):
# TODO: 等待官方库修复这个BUG
try:
return super()._perform_radius_auth(client, packet)
except UnicodeError as e:
import sys
tb = ''.join(traceback.format_exception(*sys.exc_info(), limit=2, chain=False))
if tb.find("cl.decode") != -1:
return [], False, False
return None
class RadiusBackend(CreateUserMixin, RADIUSBackend):
def authenticate(self, request, username='', password='', **kwargs):
return super().authenticate(request, username=username, password=password)
pass
class RadiusRealmBackend(CreateUserMixin, RADIUSRealmBackend):
def authenticate(self, request, username='', password='', realm=None, **kwargs):
return super().authenticate(request, username=username, password=password, realm=realm)
pass

View File

@@ -1,2 +0,0 @@
RSA_PRIVATE_KEY = 'rsa_private_key'
RSA_PUBLIC_KEY = 'rsa_public_key'

View File

@@ -4,14 +4,12 @@ from django.utils.translation import ugettext_lazy as _
from django.urls import reverse
from django.conf import settings
from common.exceptions import JMSException
from .signals import post_auth_failed
from users.utils import (
increase_login_failed_count, get_login_failed_count
)
reason_password_failed = 'password_failed'
reason_password_decrypt_failed = 'password_decrypt_failed'
reason_mfa_failed = 'mfa_failed'
reason_mfa_unset = 'mfa_unset'
reason_user_not_exist = 'user_not_exist'
@@ -21,7 +19,6 @@ reason_user_inactive = 'user_inactive'
reason_choices = {
reason_password_failed: _('Username/password check failed'),
reason_password_decrypt_failed: _('Password decrypt failed'),
reason_mfa_failed: _('MFA failed'),
reason_mfa_unset: _('MFA unset'),
reason_user_not_exist: _("Username does not exist"),
@@ -206,17 +203,3 @@ class LoginConfirmOtherError(LoginConfirmBaseError):
def __init__(self, ticket_id, status):
msg = login_confirm_error_msg.format(status)
super().__init__(ticket_id=ticket_id, msg=msg)
class SSOAuthClosed(JMSException):
default_code = 'sso_auth_closed'
default_detail = _('SSO auth closed')
class PasswdTooSimple(JMSException):
default_code = 'passwd_too_simple'
default_detail = _('Your password is too simple, please change it for security')
def __init__(self, url, *args, **kwargs):
super(PasswdTooSimple, self).__init__(*args, **kwargs)
self.url = url

View File

@@ -1,15 +0,0 @@
from rest_framework import filters
from rest_framework.compat import coreapi, coreschema
class AuthKeyQueryDeclaration(filters.BaseFilterBackend):
def get_schema_fields(self, view):
return [
coreapi.Field(
name='authkey', location='query', required=True, type='string',
schema=coreschema.String(
title='authkey',
description='authkey'
)
)
]

View File

@@ -2,7 +2,6 @@
#
from django import forms
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from captcha.fields import CaptchaField
@@ -11,7 +10,7 @@ class UserLoginForm(forms.Form):
username = forms.CharField(label=_('Username'), max_length=100)
password = forms.CharField(
label=_('Password'), widget=forms.PasswordInput,
max_length=1024, strip=False
max_length=128, strip=False
)
def confirm_login_allowed(self, user):
@@ -22,24 +21,9 @@ class UserLoginForm(forms.Form):
)
class UserCheckOtpCodeForm(forms.Form):
otp_code = forms.CharField(label=_('MFA code'), max_length=6)
class CaptchaMixin(forms.Form):
class UserLoginCaptchaForm(UserLoginForm):
captcha = CaptchaField()
class ChallengeMixin(forms.Form):
challenge = forms.CharField(label=_('MFA code'), max_length=6,
required=False)
def get_user_login_form_cls(*, captcha=False):
bases = []
if settings.SECURITY_LOGIN_CAPTCHA_ENABLED and captcha:
bases.append(CaptchaMixin)
if settings.SECURITY_LOGIN_CHALLENGE_ENABLED:
bases.append(ChallengeMixin)
bases.append(UserLoginForm)
return type('UserLoginForm', tuple(bases), {})
class UserCheckOtpCodeForm(forms.Form):
otp_code = forms.CharField(label=_('MFA code'), max_length=6)

View File

@@ -1,32 +0,0 @@
# Generated by Django 2.2.10 on 2020-07-31 08:36
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('authentication', '0003_loginconfirmsetting'),
]
operations = [
migrations.CreateModel(
name='SSOToken',
fields=[
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('authkey', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False, verbose_name='Token')),
('expired', models.BooleanField(default=False, verbose_name='Expired')),
('user', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.PROTECT, to=settings.AUTH_USER_MODEL, verbose_name='User')),
],
options={
'abstract': False,
},
),
]

View File

@@ -1,13 +1,7 @@
# -*- coding: utf-8 -*-
#
from urllib.parse import urlencode
from functools import partial
import time
from django.conf import settings
from django.contrib.auth import authenticate
from django.shortcuts import reverse
from django.contrib.auth import BACKEND_SESSION_KEY
from common.utils import get_object_or_none, get_request_ip, get_logger
from users.models import User
@@ -15,9 +9,8 @@ from users.utils import (
is_block_login, clean_failed_count
)
from . import errors
from .utils import rsa_decrypt
from .utils import check_user_valid
from .signals import post_auth_success, post_auth_failed
from .const import RSA_PRIVATE_KEY
logger = get_logger(__name__)
@@ -28,14 +21,8 @@ class AuthMixin:
def get_user_from_session(self):
if self.request.session.is_empty():
raise errors.SessionEmptyError()
if all((self.request.user,
not self.request.user.is_anonymous,
BACKEND_SESSION_KEY in self.request.session)):
user = self.request.user
user.backend = self.request.session[BACKEND_SESSION_KEY]
return user
if self.request.user and not self.request.user.is_anonymous:
return self.request.user
user_id = self.request.session.get('user_id')
if not user_id:
user = None
@@ -63,54 +50,25 @@ class AuthMixin:
logger.warn('Ip was blocked' + ': ' + username + ':' + ip)
raise errors.BlockLoginError(username=username, ip=ip)
def decrypt_passwd(self, raw_passwd):
# 获取解密密钥,对密码进行解密
rsa_private_key = self.request.session.get(RSA_PRIVATE_KEY)
if rsa_private_key is not None:
try:
return rsa_decrypt(raw_passwd, rsa_private_key)
except Exception as e:
logger.error(e, exc_info=True)
logger.error(f'Decrypt password faild: password[{raw_passwd}] rsa_private_key[{rsa_private_key}]')
return None
return raw_passwd
def check_user_auth(self, decrypt_passwd=False):
def check_user_auth(self):
self.check_is_block()
request = self.request
if hasattr(request, 'data'):
data = request.data
username = request.data.get('username', '')
password = request.data.get('password', '')
public_key = request.data.get('public_key', '')
else:
data = request.POST
username = data.get('username', '')
password = data.get('password', '')
challenge = data.get('challenge', '')
public_key = data.get('public_key', '')
username = request.POST.get('username', '')
password = request.POST.get('password', '')
public_key = request.POST.get('public_key', '')
user, error = check_user_valid(
request=request, username=username, password=password, public_key=public_key
)
ip = self.get_request_ip()
CredentialError = partial(errors.CredentialError, username=username, ip=ip, request=request)
if decrypt_passwd:
password = self.decrypt_passwd(password)
if not password:
raise CredentialError(error=errors.reason_password_decrypt_failed)
user = authenticate(request,
username=username,
password=password + challenge.strip(),
public_key=public_key)
if not user:
raise CredentialError(error=errors.reason_password_failed)
elif user.is_expired:
raise CredentialError(error=errors.reason_user_inactive)
elif not user.is_active:
raise CredentialError(error=errors.reason_user_inactive)
elif user.password_has_expired:
raise CredentialError(error=errors.reason_password_expired)
self._check_passwd_is_too_simple(user, password)
raise errors.CredentialError(
username=username, error=error, ip=ip, request=request
)
clean_failed_count(username, ip)
request.session['auth_password'] = 1
request.session['user_id'] = str(user.id)
@@ -118,30 +76,14 @@ class AuthMixin:
request.session['auth_backend'] = auth_backend
return user
@classmethod
def _check_passwd_is_too_simple(cls, user, password):
if user.is_superuser and password == 'admin':
reset_passwd_url = reverse('authentication:reset-password')
query_str = urlencode({
'token': user.generate_reset_token()
})
reset_passwd_url = f'{reset_passwd_url}?{query_str}'
flash_page_url = reverse('authentication:passwd-too-simple-flash-msg')
query_str = urlencode({
'redirect_url': reset_passwd_url
})
raise errors.PasswdTooSimple(f'{flash_page_url}?{query_str}')
def check_user_auth_if_need(self, decrypt_passwd=False):
def check_user_auth_if_need(self):
request = self.request
if request.session.get('auth_password') and \
request.session.get('user_id'):
user = self.get_user_from_session()
if user:
return user
return self.check_user_auth(decrypt_passwd=decrypt_passwd)
return self.check_user_auth()
def check_user_mfa_if_need(self, user):
if self.request.session.get('auth_mfa'):
@@ -170,12 +112,12 @@ class AuthMixin:
if not ticket_id:
ticket = None
else:
ticket = Ticket.origin_objects.get(pk=ticket_id)
ticket = get_object_or_none(Ticket, pk=ticket_id)
return ticket
def get_ticket_or_create(self, confirm_setting):
ticket = self.get_ticket()
if not ticket or ticket.status == ticket.STATUS.CLOSED:
if not ticket or ticket.status == ticket.STATUS_CLOSED:
ticket = confirm_setting.create_confirm_ticket(self.request)
self.request.session['auth_ticket_id'] = str(ticket.id)
return ticket
@@ -184,12 +126,12 @@ class AuthMixin:
ticket = self.get_ticket()
if not ticket:
raise errors.LoginConfirmOtherError('', "Not found")
if ticket.status == ticket.STATUS.OPEN:
if ticket.status == ticket.STATUS_OPEN:
raise errors.LoginConfirmWaitError(ticket.id)
elif ticket.action == ticket.ACTION.APPROVE:
elif ticket.action == ticket.ACTION_APPROVE:
self.request.session["auth_confirm"] = "1"
return
elif ticket.action == ticket.ACTION.REJECT:
elif ticket.action == ticket.ACTION_REJECT:
raise errors.LoginConfirmOtherError(
ticket.id, ticket.get_action_display()
)

View File

@@ -1,13 +1,10 @@
import uuid
from functools import partial
from django.db import models
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _, ugettext as __
from rest_framework.authtoken.models import Token
from django.conf import settings
from django.utils.crypto import get_random_string
from common.db import models
from common.mixins.models import CommonModelMixin
from common.utils import get_object_or_none, get_request_ip, get_ip_city
@@ -53,7 +50,7 @@ class LoginConfirmSetting(CommonModelMixin):
def create_confirm_ticket(self, request=None):
from tickets.models import Ticket
title = _('Login confirm') + ' {}'.format(self.user)
title = _('Login confirm') + '{}'.format(self.user)
if request:
remote_addr = get_request_ip(request)
city = get_ip_city(remote_addr)
@@ -71,7 +68,7 @@ class LoginConfirmSetting(CommonModelMixin):
reviewer = self.reviewers.all()
ticket = Ticket.objects.create(
user=self.user, title=title, body=body,
type=Ticket.TYPE.LOGIN_CONFIRM,
type=Ticket.TYPE_LOGIN_CONFIRM,
)
ticket.assignees.set(reviewer)
return ticket
@@ -79,12 +76,3 @@ class LoginConfirmSetting(CommonModelMixin):
def __str__(self):
return '{} confirm'.format(self.user.username)
class SSOToken(models.JMSBaseModel):
"""
类似腾讯企业邮的 [单点登录](https://exmail.qq.com/qy_mng_logic/doc#10036)
出于安全考虑,这里的 `token` 使用一次随即过期。但我们保留每一个生成过的 `token`。
"""
authkey = models.UUIDField(primary_key=True, default=uuid.uuid4, verbose_name=_('Token'))
expired = models.BooleanField(default=False, verbose_name=_('Expired'))
user = models.ForeignKey('users.User', on_delete=models.PROTECT, verbose_name=_('User'), db_constraint=False)

View File

@@ -5,12 +5,12 @@ from rest_framework import serializers
from common.utils import get_object_or_none
from users.models import User
from users.serializers import UserProfileSerializer
from .models import AccessKey, LoginConfirmSetting, SSOToken
from .models import AccessKey, LoginConfirmSetting
__all__ = [
'AccessKeySerializer', 'OtpVerifySerializer', 'BearerTokenSerializer',
'MFAChallengeSerializer', 'LoginConfirmSettingSerializer', 'SSOTokenSerializer',
'MFAChallengeSerializer', 'LoginConfirmSettingSerializer',
]
@@ -76,9 +76,3 @@ class LoginConfirmSettingSerializer(serializers.ModelSerializer):
model = LoginConfirmSetting
fields = ['id', 'user', 'reviewers', 'date_created', 'date_updated']
read_only_fields = ['date_created', 'date_updated']
class SSOTokenSerializer(serializers.Serializer):
username = serializers.CharField(write_only=True)
login_url = serializers.CharField(read_only=True)
next = serializers.CharField(write_only=True, allow_blank=True, required=False, allow_null=True)

View File

@@ -1,27 +1,10 @@
from importlib import import_module
from django.conf import settings
from django.contrib.auth import user_logged_in
from django.core.cache import cache
from django.dispatch import receiver
from django_cas_ng.signals import cas_user_authenticated
from jms_oidc_rp.signals import openid_user_login_failed, openid_user_login_success
from .signals import post_auth_success, post_auth_failed
@receiver(user_logged_in)
def on_user_auth_login_success(sender, user, request, **kwargs):
if settings.USER_LOGIN_SINGLE_MACHINE_ENABLED:
user_id = 'single_machine_login_' + str(user.id)
session_key = cache.get(user_id)
if session_key and session_key != request.session.session_key:
session = import_module(settings.SESSION_ENGINE).SessionStore(session_key)
session.delete()
cache.set(user_id, request.session.session_key, None)
@receiver(openid_user_login_success)
def on_oidc_user_login_success(sender, request, user, **kwargs):
post_auth_success.send(sender, user=user, request=request)
@@ -30,8 +13,3 @@ def on_oidc_user_login_success(sender, request, user, **kwargs):
@receiver(openid_user_login_failed)
def on_oidc_user_login_failed(sender, username, request, reason, **kwargs):
post_auth_failed.send(sender, username=username, request=request, reason=reason)
@receiver(cas_user_authenticated)
def on_cas_user_login_success(sender, request, user, **kwargs):
post_auth_success.send(sender, user=user, request=request)

View File

@@ -7,7 +7,7 @@
{% endblock %}
{% block content %}
<form id="form" class="m-t" role="form" method="post" action="">
<form class="m-t" role="form" method="post" action="">
{% csrf_token %}
{% if form.non_field_errors %}
<div style="line-height: 17px;">
@@ -26,28 +26,17 @@
{% endif %}
</div>
<div class="form-group">
<input type="password" class="form-control" id="password" placeholder="{% trans 'Password' %}" required="">
<input id="password-hidden" type="text" style="display:none" name="{{ form.password.html_name }}">
<input type="password" class="form-control" name="{{ form.password.html_name }}" placeholder="{% trans 'Password' %}" required="">
{% if form.errors.password %}
<div class="help-block field-error">
<p class="red-fonts">{{ form.errors.password.as_text }}</p>
</div>
{% endif %}
</div>
{% if form.challenge %}
<div class="form-group">
<input type="challenge" class="form-control" id="challenge" name="{{ form.challenge.html_name }}" placeholder="{% trans 'MFA code' %}" >
{% if form.errors.challenge %}
<div class="help-block field-error">
<p class="red-fonts">{{ form.errors.challenge.as_text }}</p>
</div>
{% endif %}
</div>
{% endif %}
<div>
{{ form.captcha }}
</div>
<button type="submit" class="btn btn-primary block full-width m-b" onclick="doLogin();return false;">{% trans 'Login' %}</button>
<button type="submit" class="btn btn-primary block full-width m-b">{% trans 'Login' %}</button>
{% if demo_mode %}
<p class="text-muted font-bold" style="color: red">
@@ -57,7 +46,7 @@
<div class="text-muted text-center">
<div>
<a id="forgot_password" href="#">
<a href="{% url 'authentication:forgot-password' %}">
<small>{% trans 'Forgot password' %}?</small>
</a>
</div>
@@ -75,32 +64,4 @@
{% endif %}
</form>
<script type="text/javascript" src="/static/js/plugins/jsencrypt/jsencrypt.min.js"></script>
<script>
function encryptLoginPassword(password, rsaPublicKey){
var jsencrypt = new JSEncrypt(); //加密对象
jsencrypt.setPublicKey(rsaPublicKey); // 设置密钥
return jsencrypt.encrypt(password); //加密
}
function doLogin() {
//公钥加密
var rsaPublicKey = "{{ rsa_public_key }}"
var password =$('#password').val(); //明文密码
var passwordEncrypted = encryptLoginPassword(password, rsaPublicKey)
$('#password-hidden').val(passwordEncrypted); //返回给密码输入input
$('#form').submit();//post提交
}
var authDB = '{{ AUTH_DB }}';
var forgotPasswordUrl = "{% url 'authentication:forgot-password' %}";
$(document).ready(function () {
}).on('click', '#forgot_password', function () {
if (authDB === 'True'){
window.open(forgotPasswordUrl, "_blank")
}
else{
alert("{% trans 'You are using another authentication server, please contact your administrator' %}")
}
})
</script>
{% endblock %}

View File

@@ -67,30 +67,22 @@
</div>
<div class="box-3">
<div style="background-color: white">
{% if form.challenge %}
<div style="margin-top: 20px;padding-top: 30px;padding-left: 20px;padding-right: 20px;height: 60px">
{% else %}
<div style="margin-top: 20px;padding-top: 40px;padding-left: 20px;padding-right: 20px;height: 80px">
{% endif %}
<span style="font-size: 21px;font-weight:400;color: #151515;letter-spacing: 0;">{{ JMS_TITLE }}</span>
</div>
<div style="margin-top: 30px;padding-top: 40px;padding-left: 20px;padding-right: 20px;height: 80px">
<span style="font-size: 21px;font-weight:400;color: #151515;letter-spacing: 0;">{{ JMS_TITLE }}</span>
</div>
<div style="font-size: 12px;color: #999999;letter-spacing: 0;line-height: 18px;margin-top: 18px">
{% trans 'Welcome back, please enter username and password to login' %}
</div>
<div style="margin-bottom: 0px">
<div style="margin-bottom: 10px">
<div>
<div class="col-md-1"></div>
<div class="contact-form col-md-10" style="margin-top: 0px;height: 35px">
<div class="contact-form col-md-10" style="margin-top: 20px;height: 35px">
<form id="contact-form" action="" method="post" role="form" novalidate="novalidate">
{% csrf_token %}
{% if form.non_field_errors %}
{% if form.challenge %}
<div style="height: 50px;color: red;line-height: 17px;">
{% else %}
<div style="height: 70px;color: red;line-height: 17px;">
{% endif %}
<p class="red-fonts">{{ form.non_field_errors.as_text }}</p>
</div>
<div style="height: 70px;color: red;line-height: 17px;">
<p class="red-fonts">{{ form.non_field_errors.as_text }}</p>
</div>
{% elif form.errors.captcha %}
<p class="red-fonts">{% trans 'Captcha invalid' %}</p>
{% else %}
@@ -106,32 +98,21 @@
{% endif %}
</div>
<div class="form-group">
<input type="password" class="form-control" id="password" placeholder="{% trans 'Password' %}" required="">
<input id="password-hidden" type="text" style="display:none" name="{{ form.password.html_name }}">
<input type="password" class="form-control" name="{{ form.password.html_name }}" placeholder="{% trans 'Password' %}" required="">
{% if form.errors.password %}
<div class="help-block field-error">
<p class="red-fonts">{{ form.errors.password.as_text }}</p>
</div>
{% endif %}
</div>
{% if form.challenge %}
<div class="form-group">
<input type="challenge" class="form-control" id="challenge" name="{{ form.challenge.html_name }}" placeholder="{% trans 'MFA code' %}" >
{% if form.errors.challenge %}
<div class="help-block field-error">
<p class="red-fonts">{{ form.errors.challenge.as_text }}</p>
</div>
{% endif %}
</div>
{% endif %}
<div class="form-group" style="height: 50px;margin-bottom: 0;font-size: 13px">
{{ form.captcha }}
</div>
<div class="form-group" style="margin-top: 10px">
<button type="submit" class="btn btn-transparent" onclick="doLogin();return false;">{% trans 'Login' %}</button>
<button type="submit" class="btn btn-transparent">{% trans 'Login' %}</button>
</div>
<div style="text-align: center">
<a id="forgot_password" href="#">
<a href="{% url 'authentication:forgot-password' %}">
<small>{% trans 'Forgot password' %}?</small>
</a>
</div>
@@ -144,36 +125,6 @@
</div>
</div>
</div>
</div>
</body>
<script type="text/javascript" src="/static/js/plugins/jsencrypt/jsencrypt.min.js"></script>
<script>
function encryptLoginPassword(password, rsaPublicKey){
var jsencrypt = new JSEncrypt(); //加密对象
jsencrypt.setPublicKey(rsaPublicKey); // 设置密钥
return jsencrypt.encrypt(password); //加密
}
function doLogin() {
//公钥加密
var rsaPublicKey = "{{ rsa_public_key }}"
var password =$('#password').val(); //明文密码
var passwordEncrypted = encryptLoginPassword(password, rsaPublicKey)
$('#password-hidden').val(passwordEncrypted); //返回给密码输入input
$('#contact-form').submit();//post提交
}
var authDB = '{{ AUTH_DB }}';
var forgotPasswordUrl = "{% url 'authentication:forgot-password' %}";
$(document).ready(function () {
}).on('click', '#forgot_password', function () {
if (authDB === 'True'){
window.open(forgotPasswordUrl, "_blank")
}
else{
alert("{% trans 'You are using another authentication server, please contact your administrator' %}")
}
})
</script>
</html>

View File

@@ -1,15 +1 @@
from .utils import gen_key_pair, rsa_decrypt, rsa_encrypt
def test_rsa_encrypt_decrypt(message='test-password-$%^&*'):
""" 测试加密/解密 """
print('Need to encrypt message: {}'.format(message))
rsa_private_key, rsa_public_key = gen_key_pair()
print('RSA public key: \n{}'.format(rsa_public_key))
print('RSA private key: \n{}'.format(rsa_private_key))
message_encrypted = rsa_encrypt(message, rsa_public_key)
print('Encrypted message: {}'.format(message_encrypted))
message_decrypted = rsa_decrypt(message_encrypted, rsa_private_key)
print('Decrypted message: {}'.format(message_decrypted))

View File

@@ -8,7 +8,6 @@ from .. import api
app_name = 'authentication'
router = DefaultRouter()
router.register('access-keys', api.AccessKeyViewSet, 'access-key')
router.register('sso', api.SSOViewSet, 'sso')
urlpatterns = [

View File

@@ -21,7 +21,6 @@ urlpatterns = [
path('password/forgot/sendmail-success/', users_view.UserForgotPasswordSendmailSuccessView.as_view(),
name='forgot-password-sendmail-success'),
path('password/reset/', users_view.UserResetPasswordView.as_view(), name='reset-password'),
path('password/too-simple-flash-msg/', views.FlashPasswdTooSimpleMsgView.as_view(), name='passwd-too-simple-flash-msg'),
path('password/reset/success/', users_view.UserResetPasswordSuccessView.as_view(), name='reset-password-success'),
path('password/verify/', users_view.UserVerifyPasswordView.as_view(), name='user-verify-password'),

View File

@@ -1,46 +1,25 @@
# -*- coding: utf-8 -*-
#
import base64
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5
from Cryptodome import Random
from django.contrib.auth import authenticate
from common.utils import get_logger
logger = get_logger(__file__)
from . import errors
def gen_key_pair():
""" 生成加密key
用于登录页面提交用户名/密码时,对密码进行加密(前端)/解密(后端)
"""
random_generator = Random.new().read
rsa = RSA.generate(1024, random_generator)
rsa_private_key = rsa.exportKey().decode()
rsa_public_key = rsa.publickey().exportKey().decode()
return rsa_private_key, rsa_public_key
def check_user_valid(**kwargs):
password = kwargs.pop('password', None)
public_key = kwargs.pop('public_key', None)
username = kwargs.pop('username', None)
request = kwargs.get('request')
user = authenticate(request, username=username,
password=password, public_key=public_key)
if not user:
return None, errors.reason_password_failed
elif user.is_expired:
return None, errors.reason_user_inactive
elif not user.is_active:
return None, errors.reason_user_inactive
elif user.password_has_expired:
return None, errors.reason_password_expired
def rsa_encrypt(message, rsa_public_key):
""" 加密登录密码 """
key = RSA.importKey(rsa_public_key)
cipher = PKCS1_v1_5.new(key)
cipher_text = base64.b64encode(cipher.encrypt(message.encode())).decode()
return cipher_text
def rsa_decrypt(cipher_text, rsa_private_key=None):
""" 解密登录密码 """
if rsa_private_key is None:
# rsa_private_key 为 None可以能是API请求认证不需要解密
return cipher_text
key = RSA.importKey(rsa_private_key)
cipher = PKCS1_v1_5.new(key)
cipher_decoded = base64.b64decode(cipher_text.encode())
# Todo: 弄明白为何要以下这么写https://xbuba.com/questions/57035263
if len(cipher_decoded) == 127:
hex_fixed = '00' + cipher_decoded.hex()
cipher_decoded = base64.b16decode(hex_fixed.upper())
message = cipher.decrypt(cipher_decoded, b'error').decode()
return message
return user, ''

View File

@@ -17,22 +17,17 @@ from django.views.generic.base import TemplateView, RedirectView
from django.views.generic.edit import FormView
from django.conf import settings
from django.urls import reverse_lazy
from django.contrib.auth import BACKEND_SESSION_KEY
from common.const.front_urls import TICKET_DETAIL
from common.utils import get_request_ip, get_object_or_none
from users.utils import (
redirect_user_first_login_or_index
)
from ..const import RSA_PRIVATE_KEY, RSA_PUBLIC_KEY
from .. import mixins, errors, utils
from ..forms import get_user_login_form_cls
from .. import forms, mixins, errors
__all__ = [
'UserLoginView', 'UserLogoutView',
'UserLoginGuardView', 'UserLoginWaitConfirmView',
'FlashPasswdTooSimpleMsgView',
]
@@ -40,6 +35,8 @@ __all__ = [
@method_decorator(csrf_protect, name='dispatch')
@method_decorator(never_cache, name='dispatch')
class UserLoginView(mixins.AuthMixin, FormView):
form_class = forms.UserLoginForm
form_class_captcha = forms.UserLoginCaptchaForm
key_prefix_captcha = "_LOGIN_INVALID_{}"
redirect_field_name = 'next'
@@ -85,19 +82,15 @@ class UserLoginView(mixins.AuthMixin, FormView):
if not self.request.session.test_cookie_worked():
return HttpResponse(_("Please enable cookies and try again."))
try:
self.check_user_auth(decrypt_passwd=True)
self.check_user_auth()
except errors.AuthFailedError as e:
form.add_error(None, e.msg)
ip = self.get_request_ip()
cache.set(self.key_prefix_captcha.format(ip), 1, 3600)
form_cls = get_user_login_form_cls(captcha=True)
new_form = form_cls(data=form.data)
new_form = self.form_class_captcha(data=form.data)
new_form._errors = form.errors
context = self.get_context_data(form=new_form)
return self.render_to_response(context)
except errors.PasswdTooSimple as e:
return redirect(e.url)
self.clear_rsa_key()
return self.redirect_to_guard_view()
def redirect_to_guard_view(self):
@@ -110,29 +103,14 @@ class UserLoginView(mixins.AuthMixin, FormView):
def get_form_class(self):
ip = get_request_ip(self.request)
if cache.get(self.key_prefix_captcha.format(ip)):
return get_user_login_form_cls(captcha=True)
return self.form_class_captcha
else:
return get_user_login_form_cls()
def clear_rsa_key(self):
self.request.session[RSA_PRIVATE_KEY] = None
self.request.session[RSA_PUBLIC_KEY] = None
return self.form_class
def get_context_data(self, **kwargs):
# 生成加解密密钥对public_key传递给前端private_key存入session中供解密使用
rsa_private_key = self.request.session.get(RSA_PRIVATE_KEY)
rsa_public_key = self.request.session.get(RSA_PUBLIC_KEY)
if not all((rsa_private_key, rsa_public_key)):
rsa_private_key, rsa_public_key = utils.gen_key_pair()
rsa_public_key = rsa_public_key.replace('\n', '\\n')
self.request.session[RSA_PRIVATE_KEY] = rsa_private_key
self.request.session[RSA_PUBLIC_KEY] = rsa_public_key
context = {
'demo_mode': os.environ.get("DEMO_MODE"),
'AUTH_OPENID': settings.AUTH_OPENID,
'rsa_public_key': rsa_public_key,
'AUTH_DB': settings.AUTH_DB
}
kwargs.update(context)
return super().get_context_data(**kwargs)
@@ -163,8 +141,6 @@ class UserLoginGuardView(mixins.AuthMixin, RedirectView):
return self.format_redirect_url(self.login_confirm_url)
except errors.MFAUnsetError as e:
return e.url
except errors.PasswdTooSimple as e:
return e.url
else:
auth_login(self.request, user)
self.send_auth_signal(success=True, user=user)
@@ -188,7 +164,7 @@ class UserLoginWaitConfirmView(TemplateView):
context = super().get_context_data(**kwargs)
if ticket:
timestamp_created = datetime.datetime.timestamp(ticket.date_created)
ticket_detail_url = TICKET_DETAIL.format(id=ticket_id)
ticket_detail_url = reverse('tickets:ticket-detail', kwargs={'pk': ticket_id})
msg = _("""Wait for <b>{}</b> confirm, You also can copy link to her/him <br/>
Don't close this page""").format(ticket.assignees_display)
else:
@@ -207,12 +183,12 @@ class UserLoginWaitConfirmView(TemplateView):
class UserLogoutView(TemplateView):
template_name = 'flash_message_standalone.html'
def get_backend_logout_url(self):
backend = self.request.session.get(BACKEND_SESSION_KEY, '')
if 'OIDC' in backend:
@staticmethod
def get_backend_logout_url():
if settings.AUTH_OPENID:
return settings.AUTH_OPENID_AUTH_LOGOUT_URL_NAME
elif 'CAS' in backend:
return settings.CAS_LOGOUT_URL_NAME
# if settings.AUTH_CAS:
# return settings.CAS_LOGOUT_URL_NAME
return None
def get(self, request, *args, **kwargs):
@@ -236,16 +212,4 @@ class UserLogoutView(TemplateView):
return super().get_context_data(**kwargs)
@method_decorator(never_cache, name='dispatch')
class FlashPasswdTooSimpleMsgView(TemplateView):
template_name = 'flash_message_standalone.html'
def get(self, request, *args, **kwargs):
context = {
'title': _('Please change your password'),
'messages': _('Your password is too simple, please change it for security'),
'interval': 5,
'redirect_url': request.GET.get('redirect_url'),
'auto_redirect': True,
}
return self.render_to_response(context)

View File

@@ -1,8 +0,0 @@
from django.utils.translation import ugettext_lazy as _
from common.db.models import ChoiceSet
ADMIN = 'Admin'
USER = 'User'
AUDITOR = 'Auditor'

View File

@@ -1,2 +0,0 @@
UPDATE_NODE_TREE_LOCK_KEY = 'org_level_transaction_lock_{org_id}_assets_update_node_tree'
UPDATE_MAPPING_NODE_TASK_LOCK_KEY = 'org_level_transaction_lock_{user_id}_update_mapping_node_task'

View File

@@ -1,2 +0,0 @@
TICKET_DETAIL = '/ui/#/tickets/tickets/{id}'

View File

@@ -1,14 +0,0 @@
"""
`m2m_changed`
```
def m2m_signals_handler(action, instance, reverse, model, pk_set, using):
pass
```
"""
PRE_ADD = 'pre_add'
POST_ADD = 'post_add'
PRE_REMOVE = 'pre_remove'
POST_REMOVE = 'post_remove'
PRE_CLEAR = 'pre_clear'
POST_CLEAR = 'post_clear'

View File

@@ -1,27 +0,0 @@
from django.db.models import Aggregate
class GroupConcat(Aggregate):
function = 'GROUP_CONCAT'
template = '%(function)s(%(expressions)s %(order_by)s %(separator)s)'
allow_distinct = False
def __init__(self, expression, order_by=None, separator=',', **extra):
order_by_clause = ''
if order_by is not None:
order = 'ASC'
prefix, body = order_by[1], order_by[1:]
if prefix == '-':
order = 'DESC'
elif prefix == '+':
pass
else:
body = order_by
order_by_clause = f'ORDER BY {body} {order}'
super().__init__(
expression,
order_by=order_by_clause,
separator=f"SEPARATOR '{separator}'",
**extra
)

View File

@@ -1,84 +0,0 @@
"""
此文件作为 `django.db.models` 的 shortcut
这样做的优点与缺点为:
优点:
- 包命名都统一为 `models`
- 用户在使用的时候只导入本文件即可
缺点:
- 此文件中添加代码的时候,注意不要跟 `django.db.models` 中的命名冲突
"""
import uuid
from django.db.models import *
from django.db.models.functions import Concat
from django.utils.translation import ugettext_lazy as _
class Choice(str):
def __new__(cls, value, label=''): # `deepcopy` 的时候不会传 `label`
self = super().__new__(cls, value)
self.label = label
return self
class ChoiceSetType(type):
def __new__(cls, name, bases, attrs):
_choices = []
collected = set()
new_attrs = {}
for k, v in attrs.items():
if isinstance(v, tuple):
v = Choice(*v)
assert v not in collected, 'Cannot be defined repeatedly'
_choices.append(v)
collected.add(v)
new_attrs[k] = v
for base in bases:
if hasattr(base, '_choices'):
for c in base._choices:
if c not in collected:
_choices.append(c)
collected.add(c)
new_attrs['_choices'] = _choices
new_attrs['_choices_dict'] = {c: c.label for c in _choices}
return type.__new__(cls, name, bases, new_attrs)
def __contains__(self, item):
return self._choices_dict.__contains__(item)
def __getitem__(self, item):
return self._choices_dict.__getitem__(item)
def get(self, item, default=None):
return self._choices_dict.get(item, default)
@property
def choices(self):
return [(c, c.label) for c in self._choices]
class ChoiceSet(metaclass=ChoiceSetType):
choices = None # 用于 Django Model 中的 choices 配置, 为了代码提示在此声明
class JMSBaseModel(Model):
created_by = CharField(max_length=32, null=True, blank=True, verbose_name=_('Created by'))
updated_by = CharField(max_length=32, null=True, blank=True, verbose_name=_('Updated by'))
date_created = DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created'))
date_updated = DateTimeField(auto_now=True, verbose_name=_('Date updated'))
class Meta:
abstract = True
class JMSModel(JMSBaseModel):
id = UUIDField(default=uuid.uuid4, primary_key=True)
class Meta:
abstract = True
def concated_display(name1, name2):
return Concat(F(name1), Value('('), F(name2), Value(')'))

View File

@@ -1,27 +0,0 @@
from common.utils import get_logger
logger = get_logger(__file__)
def get_object_if_need(model, pk):
if not isinstance(pk, model):
try:
return model.objects.get(id=pk)
except model.DoesNotExist as e:
logger.error(f'DoesNotExist: <{model.__name__}:{pk}> not exist')
raise e
return pk
def get_objects_if_need(model, pks):
if not pks:
return pks
if not isinstance(pks[0], model):
objs = list(model.objects.filter(id__in=pks))
if len(objs) != len(pks):
pks = set(pks)
exists_pks = {o.id for o in objs}
not_found_pks = ','.join(pks - exists_pks)
logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>')
return objs
return pks

View File

@@ -1,42 +0,0 @@
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from rest_framework_bulk import BulkModelViewSet
from ..mixins.api import (
SerializerMixin2, QuerySetMixin, ExtraFilterFieldsMixin, PaginatedResponseMixin,
RelationMixin, AllowBulkDestoryMixin
)
class JmsGenericViewSet(SerializerMixin2,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
GenericViewSet):
pass
class JMSModelViewSet(SerializerMixin2,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
ModelViewSet):
pass
class JMSBulkModelViewSet(SerializerMixin2,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
AllowBulkDestoryMixin,
BulkModelViewSet):
pass
class JMSBulkRelationModelViewSet(SerializerMixin2,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
RelationMixin,
AllowBulkDestoryMixin,
BulkModelViewSet):
pass

View File

@@ -1,49 +0,0 @@
from django.core.exceptions import PermissionDenied, ObjectDoesNotExist as DJObjectDoesNotExist
from django.http import Http404
from django.utils.translation import gettext
from rest_framework import exceptions
from rest_framework.views import set_rollback
from rest_framework.response import Response
from common.exceptions import JMSObjectDoesNotExist
from common.utils import get_logger
logger = get_logger(__name__)
def extract_object_name(exc, index=0):
"""
`index` 是从 0 开始数的, 比如:
`No User matches the given query.`
提取 `User``index=1`
"""
(msg, *_) = exc.args
return gettext(msg.split(sep=' ', maxsplit=index + 1)[index])
def common_exception_handler(exc, context):
logger.exception('')
if isinstance(exc, Http404):
exc = JMSObjectDoesNotExist(object_name=extract_object_name(exc, 1))
elif isinstance(exc, PermissionDenied):
exc = exceptions.PermissionDenied()
elif isinstance(exc, DJObjectDoesNotExist):
exc = JMSObjectDoesNotExist(object_name=extract_object_name(exc, 0))
if isinstance(exc, exceptions.APIException):
headers = {}
if getattr(exc, 'auth_header', None):
headers['WWW-Authenticate'] = exc.auth_header
if getattr(exc, 'wait', None):
headers['Retry-After'] = '%d' % exc.wait
if isinstance(exc.detail, str) and isinstance(exc.get_codes(), str):
data = {'detail': exc.detail, 'code': exc.get_codes()}
else:
data = exc.detail
set_rollback()
return Response(data, status=exc.status_code, headers=headers)
return None

View File

@@ -1,43 +0,0 @@
from uuid import UUID
from rest_framework.fields import get_attribute
from rest_framework.relations import ManyRelatedField, PrimaryKeyRelatedField, MANY_RELATION_KWARGS
class GroupConcatedManyRelatedField(ManyRelatedField):
def get_attribute(self, instance):
if hasattr(instance, 'pk') and instance.pk is None:
return []
attr = self.source_attrs[-1]
# `gc` 是 `GroupConcat` 的缩写
gc_attr = f'gc_{attr}'
if hasattr(instance, gc_attr):
gc_value = getattr(instance, gc_attr)
if isinstance(gc_value, str):
return [UUID(pk) for pk in set(gc_value.split(','))]
else:
return ''
relationship = get_attribute(instance, self.source_attrs)
return relationship.all() if hasattr(relationship, 'all') else relationship
class GroupConcatedPrimaryKeyRelatedField(PrimaryKeyRelatedField):
@classmethod
def many_init(cls, *args, **kwargs):
list_kwargs = {'child_relation': cls(*args, **kwargs)}
for key in kwargs:
if key in MANY_RELATION_KWARGS:
list_kwargs[key] = kwargs[key]
return GroupConcatedManyRelatedField(**list_kwargs)
def to_representation(self, value):
if self.pk_field is not None:
return self.pk_field.to_representation(value.pk)
if hasattr(value, 'pk'):
return value.pk
else:
return value

View File

@@ -6,27 +6,18 @@ import chardet
import codecs
import unicodecsv
from django.utils.translation import ugettext as _
from rest_framework.parsers import BaseParser
from rest_framework.exceptions import ParseError, APIException
from rest_framework import status
from rest_framework.exceptions import ParseError
from common.utils import get_logger
logger = get_logger(__file__)
class CsvDataTooBig(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_code = 'csv_data_too_big'
default_detail = _('The max size of CSV is %d bytes')
class JMSCSVParser(BaseParser):
"""
Parses CSV file to serializer data
"""
CSV_UPLOAD_MAX_SIZE = 1024 * 1024 * 10
media_type = 'text/csv'
@@ -47,39 +38,31 @@ class JMSCSVParser(BaseParser):
yield row
@staticmethod
def _get_fields_map(serializer_cls):
def _get_fields_map(serializer):
fields_map = {}
fields = serializer_cls().fields
fields = serializer.fields
fields_map.update({v.label: k for k, v in fields.items()})
fields_map.update({k: k for k, _ in fields.items()})
return fields_map
@staticmethod
def _replace_chinese_quot(str_):
trans_table = str.maketrans({
'': '"',
'': '"',
'': '"',
'': '"',
'\'': '"'
})
return str_.translate(trans_table)
@classmethod
def _process_row(cls, row):
def _process_row(row):
"""
构建json数据前的行处理
"""
_row = []
for col in row:
# 列表转换
if isinstance(col, str) and col.startswith('[') and col.endswith(']'):
col = cls._replace_chinese_quot(col)
if isinstance(col, str) and col.find("[") != -1 and col.find("]") != -1:
# 替换中文格式引号
col = col.replace("", '"').replace("", '"').\
replace("", '"').replace('', '"').replace("'", '"')
col = json.loads(col)
# 字典转换
if isinstance(col, str) and col.startswith("{") and col.endswith("}"):
col = cls._replace_chinese_quot(col)
if isinstance(col, str) and col.find("{") != -1 and col.find("}") != -1:
# 替换中文格式引号
col = col.replace("", '"').replace("", '"'). \
replace("", '"').replace('', '"').replace("'", '"')
col = json.loads(col)
_row.append(col)
return _row
@@ -99,19 +82,11 @@ class JMSCSVParser(BaseParser):
def parse(self, stream, media_type=None, parser_context=None):
parser_context = parser_context or {}
try:
view = parser_context['view']
meta = view.request.META
serializer_cls = view.get_serializer_class()
serializer = parser_context["view"].get_serializer()
except Exception as e:
logger.debug(e, exc_info=True)
raise ParseError('The resource does not support imports!')
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
if content_length > self.CSV_UPLOAD_MAX_SIZE:
msg = CsvDataTooBig.default_detail % self.CSV_UPLOAD_MAX_SIZE
logger.error(msg)
raise CsvDataTooBig(msg)
try:
stream_data = stream.read()
stream_data = stream_data.strip(codecs.BOM_UTF8)
@@ -121,7 +96,7 @@ class JMSCSVParser(BaseParser):
rows = self._gen_rows(binary, charset=encoding)
header = next(rows)
fields_map = self._get_fields_map(serializer_cls)
fields_map = self._get_fields_map(serializer)
header = [fields_map.get(name.strip('*'), '') for name in header]
data = []

View File

@@ -1,25 +0,0 @@
from rest_framework.serializers import Serializer
from rest_framework.serializers import ModelSerializer
from rest_framework import serializers
from rest_framework_bulk.serializers import BulkListSerializer
from common.mixins.serializers import BulkSerializerMixin
from common.mixins import BulkListSerializerMixin
__all__ = ['EmptySerializer', 'BulkModelSerializer']
class EmptySerializer(Serializer):
pass
class BulkModelSerializer(BulkSerializerMixin, ModelSerializer):
pass
class AdaptedBulkListSerializer(BulkListSerializerMixin, BulkListSerializer):
pass
class CeleryTaskSerializer(serializers.Serializer):
task = serializers.CharField(read_only=True)

Some files were not shown because too many files have changed in this diff Show More