Compare commits


106 Commits
v2.5 ... v2.6.0

Author SHA1 Message Date
Jiangjie.Bai
3f4877f26b Merge pull request #5295 from jumpserver/dev
fix: Fix the command list filter field label `Session ID`
2020-12-17 18:26:54 +08:00
Bai
0e4d778335 fix: Fix the command list filter field label Session ID 2020-12-17 18:25:48 +08:00
Jiangjie.Bai
52d20080ff Merge pull request #5293 from jumpserver/dev
fix: Fix system monitoring issues
2020-12-17 17:36:06 +08:00
Bai
ed8d72c06b fix: Fix system monitoring issues 2020-12-17 17:34:37 +08:00
Jiangjie.Bai
5e9e3ec6f6 Merge pull request #5288 from jumpserver/dev
chore: Merge master from dev
2020-12-17 14:56:25 +08:00
xinwen
4f5f92deb8 fix: Improve the ugly error message when bulk-deleting admin users 2020-12-17 14:47:37 +08:00
xinwen
d2a15ee702 fix: Do not recommend system users on asset-request tickets when the system user field is left empty 2020-12-17 14:35:08 +08:00
Bai
a3a591da4b fix: Fix error when exporting commands in excel format 2020-12-17 14:31:34 +08:00
Bai
3f2925116e fix: Fix metrics filtering terminals by the is_deleted field 2020-12-17 11:30:29 +08:00
xinwen
c3e2e536e0 fix: [User management] - Create user group - system auditors could be added to user groups #579 2020-12-17 10:33:45 +08:00
Jiangjie.Bai
8c133d5fdb Merge pull request #5278 from jumpserver/dev
chore: Merge master from dev
2020-12-16 18:50:49 +08:00
xinwen
89d8efe0f1 fix: Rename perms.signals_handler.on_application_permission_applications_changed 2020-12-16 18:49:20 +08:00
Bai
54303ea33f fix: Fix log output when updating children's full_value on node creation 2020-12-16 18:37:54 +08:00
Bai
4dcd8dd8dd fix: Fix log output when updating children's full_value on node creation 2020-12-16 18:37:54 +08:00
老广
4f04a7d258 Merge pull request #5280 from jumpserver/pr@dev@fix_auto_push
fix: Duplicate AdHocExecution id when pushing system users
2020-12-16 17:36:45 +08:00
xinwen
bf308e24b6 fix: Duplicate AdHocExecution id when pushing system users 2020-12-16 17:31:37 +08:00
Bai
b3642f3ff4 fix: Fix the circular-call issue when an LDAP user is not found during login 2020-12-16 12:01:57 +08:00
xinwen
3aed4955c8 fix: Fix several issues with remote application permissions 2020-12-16 12:00:53 +08:00
fit2bot
9a5f9a9c92 fix: Bug where application permissions were not auto-pushed (#5271)
Co-authored-by: xinwen <coderWen@126.com>
2020-12-16 10:23:37 +08:00
Orange
6d5bec1ef2 Merge pull request #5269 from jumpserver/dev
chore: Merge master from dev
2020-12-15 20:35:51 +08:00
Bai
e93fd1fd44 fix: Remove the default value 0 of state in the terminal list 2020-12-15 20:34:47 +08:00
xinwen
7bf37611bd fix: System auditors should not be addable to user groups 2020-12-15 19:24:25 +08:00
xinwen
b8ec4bfaa5 fix: Audit log - operation log search - search by action: change the wording Delete file --> Delete #524 2020-12-15 18:38:23 +08:00
Bai
58b6293b76 fix: Add offline count to component monitoring 2020-12-15 18:37:16 +08:00
xinwen
8e12eebceb fix: Failed to get authentication backends such as oidc and acs 2020-12-15 18:36:37 +08:00
Bai
72d6ea43fa fix: Change the command list filter parameter to session 2020-12-15 18:34:05 +08:00
Bai
deedd49dc5 fix: Fix undefined excel file format when exporting command records 2020-12-15 18:07:11 +08:00
fit2bot
a36e6fbf84 fix: Change the session-active check logic; no need to check the protocol (#5262)
* fix: Change the session-active check logic; no need to check the protocol

* fix: Fix the task import issue

Co-authored-by: Bai <bugatti_it@163.com>
2020-12-15 18:06:35 +08:00
Bai
b57453cc3c fix: Fix command list filtering to use the session_id field 2020-12-15 18:05:42 +08:00
Jiangjie.Bai
62f2909d59 Merge pull request #5256 from jumpserver/dev
chore: Merge master from dev
2020-12-15 14:20:23 +08:00
xinwen
0d469ff95b fix(orgs): Authorized assets were not refreshed after a user left an organization 2020-12-15 14:00:51 +08:00
xinwen
ca883f1fb4 fix: System users were not recommended when approving asset-request tickets 2020-12-15 13:06:55 +08:00
fit2bot
6e0fbd78e7 fix: Fix data-fetching bug in the prometheus_metrics API; fix empty type on component registration (#5253)
* fix: Fix data-fetching bug in the prometheus_metrics API; fix empty type on component registration

* fix: Change the verbose_name of the backend field in the audits userloginlog migration

Co-authored-by: Bai <bugatti_it@163.com>
2020-12-15 13:00:59 +08:00
ibuler
0813cff74f perf: Improve docs swagger; no longer require debug 2020-12-14 11:39:02 +08:00
ibuler
ff428b84f9 fix(assets): Fix log error when an asset updates child nodes 2020-12-14 11:33:34 +08:00
ibuler
d0c9aa2c55 fix(assets): Fix log error when updating child nodes 2020-12-14 11:33:34 +08:00
老广
1d5e603c0d Merge pull request #5231 from jumpserver/dev
chore(merge): Merge dev into master
2020-12-11 19:29:45 +08:00
ibuler
ddbbc8df17 fix(docker): Fix sh/bash newline compatibility issue caused by echo in the Dockerfile 2020-12-11 19:26:46 +08:00
Bai
90df404931 fix: Fix swagger issue 2020-12-11 19:02:33 +08:00
Bai
b9cbff1a5f del: Remove test prometheus-related code 2020-12-11 19:02:33 +08:00
ibuler
b9717eece3 fix: Fix errors raised when visiting swagger 2020-12-11 18:30:20 +08:00
Bai
f9cf2a243b fix: Fix duplicate results when searching LDAP users in settings 2020-12-11 18:26:10 +08:00
Bai
e056430fce fix: Fix LDAP user authentication failure when only the DC domain is configured 2020-12-11 18:26:10 +08:00
Jiangjie.Bai
2b2821c0a1 Merge pull request #5223 from jumpserver/dev
chore(merge): Merge dev into master
2020-12-11 16:53:36 +08:00
Bai
213221beae perf: Change custom_filter_fields of BasePermissionViewSet 2020-12-11 16:19:18 +08:00
Bai
2db9c90a74 feat: Update translation: authentication method 2020-12-11 15:49:39 +08:00
Bai
8ced6f1168 fix: is_first_login in the user Profile API is not read-write 2020-12-11 15:44:17 +08:00
Bai
6703ab9a77 perf: Add BasePermissionsViewSet with search filtering support 2020-12-11 15:44:17 +08:00
老广
2fc6e6cd54 Merge pull request #5213 from jumpserver/dev
chore(merge): Merge dev into master
2020-12-10 23:03:35 +08:00
Bai
2176fd8fac feat: Update translations 2020-12-10 21:30:56 +08:00
fit2bot
856e7c16e5 feat: Add component monitoring; add type field to TerminalModel; (#5206)
* feat: Add component monitoring; add type field to TerminalModel;

* feat: Add type field to the Terminal serializer

* feat: Make the type field of the Terminal serializer read-only

* feat: Update component status wording

* feat: Remove the count field from the component status upload serializer

* reactor: Restructure the termina/models directory

* feat: Modify ComponentTypeChoices

* feat: Stop considering the CoreComponent type

* feat: Change the Terminal status evaluation logic

* feat: Add status filtering to the terminal list; add a default value to the component status serializer

* feat: Add PrometheusMetricsAPI

* feat: Modify PrometheusMetricsAPI

Co-authored-by: Bai <bugatti_it@163.com>
2020-12-10 20:50:22 +08:00
fit2bot
d4feaf1e08 fix: Fix missing css caused by upgrading the django captcha version (#5204)
* fix: Fix missing css caused by upgrading the django captcha version

* perf: Adjust the captcha height

Co-authored-by: ibuler <ibuler@qq.com>
2020-12-10 20:48:10 +08:00
fit2bot
5aee2ce3db chore: Upgrade dependency versions (#5205)
* chore: Upgrade dependency versions

* fix: Roll back several libraries by a few versions

Co-authored-by: ibuler <ibuler@qq.com>
2020-12-10 20:46:45 +08:00
xinwen
4424c4bde2 perf(asset): Distinguish organizations when manually starting the node asset-amount self-check 2020-12-10 18:17:01 +08:00
fit2bot
5863e3e008 perf(asset): Asset tree: add a right-click menu entry to have the backend recount node asset amounts #527 (#5207)
Co-authored-by: xinwen <coderWen@126.com>
2020-12-10 17:12:39 +08:00
xinwen
79a371eb6c perf(auth): Go through the password reset flow after the password expires #530 2020-12-10 16:06:26 +08:00
fit2bot
7c7de96158 feat(login): Login logs should show which backend was used to log in #4472 (#5199)
Co-authored-by: xinwen <coderWen@126.com>
2020-12-09 18:43:13 +08:00
ibuler
80b03e73f6 feat(celery): Add a celery health check API 2020-12-09 18:06:34 +08:00
fit2bot
32dbab2e34 perf: Add allow_null=True to the database application's database field (#5196)
* perf: Make the database application's database field required

* perf: Add allow_null=True to the database application's database field

Co-authored-by: Bai <bugatti_it@163.com>
2020-12-09 13:44:07 +08:00
ibuler
b189e363cc revert(system): Temporarily remove the system component 2020-12-09 11:24:44 +08:00
Bai
4c3a655239 perf: Forbid modifying the source field in the user serializer 2020-12-08 20:54:33 +08:00
Bai
5533114db5 feat: Split the user's authorized application tree by organization node 2020-12-08 20:37:07 +08:00
Jiangjie.Bai
4c469afa95 feat: Remove read-only mode from asset configuration-related fields (#5182)
* feat: Remove read-only mode from asset configuration-related fields

* feat: Remove read-only mode from asset configuration-related fields
2020-12-08 20:32:48 +08:00
fit2bot
2ccc5beeda perf(Dockerfile): Stop using zh_CN.UTF-8; en_US.UTF-8 should also work (#5190)
* perf(Dockerfile): Stop using zh_CN.UTF-8; en_US.UTF-8 should also work

* fix: Restore the original LANG setting

* perf: Merge layers

* feat: Modify the Dockerfile

Co-authored-by: ibuler <ibuler@qq.com>
Co-authored-by: Bai <bugatti_it@163.com>
2020-12-08 20:22:48 +08:00
xinwen
4b67d6925e feat(asset): Add API to push a system user to multiple assets 2020-12-08 18:08:44 +08:00
fit2bot
dd979f582a stash (#5178)
* Dev (#4791)

* fix(xpack): Fix last login being too long (#4786)

Co-authored-by: ibuler <ibuler@qq.com>

* perf: Also send email when updating passwords (#4789)

Co-authored-by: ibuler <ibuler@qq.com>

* fix(terminal): Fix the async api for fetching 螺旋

* fix(terminal): Fix a bug where problematic replay storages caused recording downloads to fail

* fix(orgs): Fix bug when adding users to an organization

* perf(requirements): Change jms-storage==0.0.34 (#4797)

Co-authored-by: Bai <bugatti_it@163.com>

Co-authored-by: fit2bot <68588906+fit2bot@users.noreply.github.com>
Co-authored-by: ibuler <ibuler@qq.com>
Co-authored-by: Bai <bugatti_it@163.com>

* stash

* feat(system): Add the system app

* stash

* fix: Fix some bugs

Co-authored-by: xinwen <coderWen@126.com>
Co-authored-by: ibuler <ibuler@qq.com>
Co-authored-by: Bai <bugatti_it@163.com>
Co-authored-by: Jiangjie.Bai <32935519+BaiJiangJie@users.noreply.github.com>
2020-12-08 14:26:18 +08:00
Bai
042ea5e137 feat: Return the org_name field from the authorized application tree API 2020-12-08 11:01:09 +08:00
ibuler
2a6f68c7ba revert: Restore the original jms; auto-run migrations 2020-12-07 19:16:27 +08:00
fit2bot
43b5e97b95 feat(excel): Add Excel import/export (#5124)
* refactor(drf_renderer): Add ExcelRenderer to support exporting excel files; improve CSVRenderer, abstract a BaseRenderer

* perf(renderer): Support exporting resource details

* refactor(drf_parser): Add ExcelParser to support importing excel files; improve CSVParser, abstract a BaseParser

* refactor(drf_parser): Add ExcelParser to support importing excel files; improve CSVParser, abstract a BaseParser 2

* perf(renderer): Catch renderer processing exceptions

* perf: Add excel dependency packages

* perf(drf): Improve import/export error logging

* perf: Add dependency pyexcel-io==0.6.4

* perf: Add dependency pyexcel-xlsx==0.6.0

* feat: Rename variables in drf/renderer & parser

* feat: Fix a bug in drf/renderer

* feat: Rename variables in drf/renderer & parser

Co-authored-by: Bai <bugatti_it@163.com>
2020-12-07 15:23:05 +08:00
ibuler
619b521ea1 fix: Update language i18n 2020-12-05 13:22:22 +08:00
fit2bot
3447eeda68 fix(applications): attrs must not be null (#5172)
Co-authored-by: Bai <bugatti_it@163.com>
2020-12-04 13:10:59 +08:00
fit2bot
75ef413ea5 fix(applications): attrs must not be null (#5171)
Co-authored-by: Bai <bugatti_it@163.com>
2020-12-04 10:24:10 +08:00
fit2bot
662c9092dc reactor(dockerfile): Build the docker image on debian (#5169)
Co-authored-by: ibuler <ibuler@qq.com>
2020-12-03 19:14:28 +08:00
ibuler
c8d54b28e2 perf: Improve variable naming 2020-12-03 14:23:16 +08:00
ibuler
96cd307d1f perf: Improve entrypoint.sh 2020-12-03 14:23:16 +08:00
ibuler
6385cb3f86 perf: Improve startup; no longer auto-run migrations 2020-12-03 14:23:16 +08:00
Bai
36e9d8101a fix: Add migration file: Node ordering 2020-12-03 11:16:27 +08:00
ibuler
3354ab8ce9 fix(req): fix wheel version 2020-12-03 10:47:21 +08:00
Bai
89ec6ba6ef fix: Node ordering [parent_key, value]; fix display of the default organization's Default node (a Default node with key 0 existed) 2020-12-03 10:45:25 +08:00
Bai
af40e46a75 fix: Improve the Default node migration 2020-12-03 10:29:00 +08:00
Bai
86fcd3c251 fix: Add migration file (change the Default node's key from 0 to 1 if needed) 2020-12-03 10:29:00 +08:00
fit2bot
c2d5928273 build(pip): Pin the pip version (#5152)
* build(pip): Pin the pip version

* fix: Pin the pip version

* fix(req): Pin the cryptography library version

* fix(build): Use the pip cache

Co-authored-by: ibuler <ibuler@qq.com>
2020-12-02 11:09:39 +08:00
xinwen
e656ba70ec fix(assets): When pushing a dynamic system user without a specified username, use all usernames 2020-12-01 20:08:54 +08:00
xinwen
bb807e6251 fix(perms): Dynamic users may fail to be pushed when creating a new permission 2020-12-01 20:08:54 +08:00
Bai
bbd6cae3d7 perf(org): Improve fetching of the org_name field 2020-11-30 15:21:56 +08:00
Bai
c3b09dd800 perf(perms): Improve returning the org_name field in the user permission tree; add thread_local org_mapper attribute to reduce queries 2020-11-30 15:21:56 +08:00
Bai
6d427b9834 fix: Forbid deleting the organization root node 2020-11-30 14:19:20 +08:00
xinwen
610aaf5244 fix(assets): Dynamic system users were not pushed to assets when the user relationship changed 2020-11-26 15:20:09 +08:00
xinwen
df2f1b3e6e perf(User): User list is slow with large-scale data 2020-11-26 12:32:18 +08:00
fit2bot
f26b7a470a perf(celery-task): Improve the Celery task that checks node asset amounts (#5052)
Co-authored-by: xinwen <coderWen@126.com>
2020-11-25 17:30:07 +08:00
xinwen
a4667f3312 fix(Node): Set parent_key in a signal when saving a Node 2020-11-25 16:35:12 +08:00
xinwen
91081d9423 refactor(perms): In a permission rule bound to dynamic users, when the rule is granted to a user group and the group gains new members, the dynamic system user does not gain those users and therefore does not auto-push (#5084) (#5086) 2020-11-24 19:31:45 +08:00
fit2bot
3041697edc fix(orgs): Stay compatible with the old organization-user relation API (#5088)
Co-authored-by: xinwen <coderWen@126.com>
2020-11-24 19:09:14 +08:00
xinwen
75d7530ea5 fix(perms): In a permission rule bound to dynamic users, when the rule is granted to a user group and the group gains new members, the dynamic system user does not gain those users and therefore does not auto-push 2020-11-24 10:27:03 +08:00
ibuler
975cc41bce perf(build): Improve use of the pip mirror 2020-11-22 18:16:45 +08:00
ibuler
439999381d perf(build): Improve the mirror used at build time 2020-11-22 17:51:24 +08:00
xinwen
39ab5978be perf(perms): Convert to a list when fetching all of a user's permissions 2020-11-22 17:24:00 +08:00
ibuler
7be7c8cee1 fix(perms): Fix issues on the My Assets page 2020-11-22 16:55:21 +08:00
xinwen
68b22cbdec fix(perms): Fix user group permission tree and asset issues 2020-11-22 15:03:03 +08:00
xinwen
a7c704bea3 perf(celery-task): Improve the Celery task that checks node asset amounts 2020-11-22 11:36:10 +08:00
xinwen
21993b0d89 perf(perms): Improve loading speed of the user's authorized asset list 2020-11-22 11:35:13 +08:00
xinwen
73ccf3be5f fix(perms): Clear the old permission tree when the user's permissions are empty 2020-11-22 11:26:39 +08:00
ibuler
bf3056abc4 fix(django3): Fix django3 compatibility issues 2020-11-20 15:25:37 +08:00
xinwen
f2fd9f5990 perf(assets): Limit the number of results returned when searching authorized assets 2020-11-20 15:24:33 +08:00
fit2bot
6d39a51c36 [fix]: Compatibility with django 3 (#5038)
* chore(django): Update version dependencies

* [fix]: Compatibility with django 3

* fix(merge): Remove the unused JSONField

* fix(requirements): Change the cryptography library version

Co-authored-by: ibuler <ibuler@qq.com>
2020-11-19 15:50:31 +08:00
xinwen
7fa94008c9 fix(old-api): Adjust the old organization-user relation API 2020-11-19 15:21:16 +08:00
130 changed files with 2983 additions and 1426 deletions

View File

@@ -1,5 +1,6 @@
FROM registry.fit2cloud.com/public/python:v3 as stage-build
MAINTAINER Jumpserver Team <ibuler@qq.com>
# Build the code
FROM python:3.8.6-slim as stage-build
MAINTAINER JumpServer Team <ibuler@qq.com>
ARG VERSION
ENV VERSION=$VERSION
@@ -8,33 +9,38 @@ ADD . .
RUN cd utils && bash -ixeu build.sh
FROM registry.fit2cloud.com/public/python:v3
# Build the runtime environment
FROM python:3.8.6-slim
ARG PIP_MIRROR=https://pypi.douban.com/simple
ENV PIP_MIRROR=$PIP_MIRROR
ARG MYSQL_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/mysql/yum/mysql57-community-el6/
ENV MYSQL_MIRROR=$MYSQL_MIRROR
ARG PIP_JMS_MIRROR=https://pypi.douban.com/simple
ENV PIP_JMS_MIRROR=$PIP_JMS_MIRROR
WORKDIR /opt/jumpserver
COPY ./requirements ./requirements
RUN useradd jumpserver
RUN yum -y install epel-release && \
echo -e "[mysql]\nname=mysql\nbaseurl=${MYSQL_MIRROR}\ngpgcheck=0\nenabled=1" > /etc/yum.repos.d/mysql.repo
RUN yum -y install $(cat requirements/rpm_requirements.txt)
RUN pip install --upgrade pip setuptools==49.6.0 wheel -i ${PIP_MIRROR} && \
pip config set global.index-url ${PIP_MIRROR}
RUN pip install $(grep 'jms' requirements/requirements.txt) -i https://pypi.org/simple
RUN pip install -r requirements/requirements.txt
COPY ./requirements/deb_buster_requirements.txt ./requirements/deb_buster_requirements.txt
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \
&& sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \
&& apt update \
&& grep -v '^#' ./requirements/deb_buster_requirements.txt | xargs apt -y install \
&& localedef -c -f UTF-8 -i zh_CN zh_CN.UTF-8 \
&& cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
COPY ./requirements/requirements.txt ./requirements/requirements.txt
RUN pip install --upgrade pip==20.2.4 setuptools==49.6.0 wheel==0.34.2 -i ${PIP_MIRROR} \
&& pip config set global.index-url ${PIP_MIRROR} \
&& pip install --no-cache-dir $(grep 'jms' requirements/requirements.txt) -i ${PIP_JMS_MIRROR} \
&& pip install --no-cache-dir -r requirements/requirements.txt
COPY --from=stage-build /opt/jumpserver/release/jumpserver /opt/jumpserver
RUN mkdir -p /root/.ssh/ && echo -e "Host *\n\tStrictHostKeyChecking no\n\tUserKnownHostsFile /dev/null" > /root/.ssh/config
RUN mkdir -p /root/.ssh/ \
&& echo "Host *\n\tStrictHostKeyChecking no\n\tUserKnownHostsFile /dev/null" > /root/.ssh/config
RUN echo > config.yml
VOLUME /opt/jumpserver/data
VOLUME /opt/jumpserver/logs
ENV LANG=zh_CN.UTF-8
ENV LC_ALL=zh_CN.UTF-8
EXPOSE 8070
EXPOSE 8080

View File

@@ -1,4 +1,7 @@
import uuid
from common.exceptions import JMSException
from orgs.models import Organization
from .. import models
@@ -6,6 +9,8 @@ class ApplicationAttrsSerializerViewMixin:
def get_serializer_class(self):
serializer_class = super().get_serializer_class()
if getattr(self, 'swagger_fake_view', False):
return serializer_class
app_type = self.request.query_params.get('type')
app_category = self.request.query_params.get('category')
type_options = list(dict(models.Category.get_all_type_serializer_mapper()).keys())
@@ -37,12 +42,16 @@ class ApplicationAttrsSerializerViewMixin:
if app_type:
# action: create / update / list / retrieve / metadata
attrs_cls = models.Category.get_type_serializer_cls(app_type)
class_name = 'ApplicationDynamicSerializer{}'.format(app_type.title())
elif app_category:
# action: list / retrieve / metadata
attrs_cls = models.Category.get_category_serializer_cls(app_category)
class_name = 'ApplicationDynamicSerializer{}'.format(app_category.title())
else:
attrs_cls = models.Category.get_no_password_serializer_cls()
return type('ApplicationDynamicSerializer', (serializer_class,), {'attrs': attrs_cls()})
class_name = 'ApplicationDynamicSerializer'
cls = type(class_name, (serializer_class,), {'attrs': attrs_cls()})
return cls
class SerializeApplicationToTreeNodeMixin:
@@ -85,11 +94,46 @@ class SerializeApplicationToTreeNodeMixin:
'meta': {'type': 'k8s_app'}
}
def _serialize(self, application):
def _serialize_application(self, application):
method_name = f'_serialize_{application.category}'
data = getattr(self, method_name)(application)
data.update({
'pId': application.org.id,
'org_name': application.org_name
})
return data
def serialize_applications(self, applications):
data = [self._serialize(application) for application in applications]
data = [self._serialize_application(application) for application in applications]
return data
@staticmethod
def _serialize_organization(org):
return {
'id': org.id,
'name': org.name,
'title': org.name,
'pId': '',
'open': True,
'isParent': True,
'meta': {
'type': 'node'
}
}
def serialize_organizations(self, organizations):
data = [self._serialize_organization(org) for org in organizations]
return data
@staticmethod
def filter_organizations(applications):
organizations_id = set(applications.values_list('org_id', flat=True))
organizations = [Organization.get_instance(org_id) for org_id in organizations_id]
return organizations
def serialize_applications_with_org(self, applications):
organizations = self.filter_organizations(applications)
data_organizations = self.serialize_organizations(organizations)
data_applications = self.serialize_applications(applications)
data = data_organizations + data_applications
return data
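
For orientation, a minimal sketch of the payload shape serialize_applications_with_org() now returns, based on the serializers above; the ids, names and the application's meta type are illustrative values only, not taken from the code.

# Illustrative only: shape of the tree data returned by serialize_applications_with_org().
tree = [
    # Organization nodes come first as top-level parents (pId is empty, isParent is True)
    {'id': 'org-uuid-1', 'name': 'Default', 'title': 'Default', 'pId': '',
     'open': True, 'isParent': True, 'meta': {'type': 'node'}},
    # Application nodes follow, attached to their organization via pId and carrying org_name
    {'id': 'app-uuid-1', 'name': 'k8s-cluster-1', 'pId': 'org-uuid-1',
     'org_name': 'Default', 'meta': {'type': 'k8s_app'}},
]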

View File

@@ -0,0 +1,18 @@
# Generated by Django 3.1 on 2020-11-19 03:10
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('applications', '0006_application'),
]
operations = [
migrations.AlterField(
model_name='application',
name='attrs',
field=models.JSONField(),
),
]

View File

@@ -2,7 +2,6 @@ from itertools import chain
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django_mysql.models import JSONField, QuerySet
from orgs.mixins.models import OrgModelMixin
from common.mixins import CommonModelMixin
@@ -123,7 +122,7 @@ class Application(CommonModelMixin, OrgModelMixin):
domain = models.ForeignKey('assets.Domain', null=True, blank=True, related_name='applications', verbose_name=_("Domain"), on_delete=models.SET_NULL)
category = models.CharField(max_length=16, choices=Category.choices, verbose_name=_('Category'))
type = models.CharField(max_length=16, choices=Category.get_all_type_choices(), verbose_name=_('Type'))
attrs = JSONField()
attrs = models.JSONField()
comment = models.TextField(
max_length=128, default='', blank=True, verbose_name=_('Comment')
)

View File

@@ -27,10 +27,8 @@ class ApplicationSerializer(BulkOrgResourceModelSerializer):
]
def create(self, validated_data):
attrs = validated_data.pop('attrs', {})
validated_data['attrs'] = validated_data.pop('attrs', {})
instance = super().create(validated_data)
instance.attrs = attrs
instance.save()
return instance
def update(self, instance, validated_data):

View File

@@ -12,9 +12,8 @@ from .. import models
class DBAttrsSerializer(serializers.Serializer):
host = serializers.CharField(max_length=128, label=_('Host'))
port = serializers.IntegerField(label=_('Port'))
database = serializers.CharField(
max_length=128, required=False, allow_blank=True, allow_null=True, label=_('Database')
)
# Add allow_null=True for compatibility with existing records where the database field is None
database = serializers.CharField(max_length=128, required=True, allow_null=True, label=_('Database'))
class MySQLAttrsSerializer(DBAttrsSerializer):

View File

@@ -115,7 +115,6 @@ class AssetTaskCreateApi(generics.CreateAPIView):
class AssetGatewayListApi(generics.ListAPIView):
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.GatewayWithAuthSerializer
model = Asset
def get_queryset(self):
asset_id = self.kwargs.get('pk')

View File

@@ -69,6 +69,7 @@ class SerializeToTreeNodeMixin:
'ip': asset.ip,
'protocols': asset.protocols_as_list,
'platform': asset.platform_base,
'org_name': asset.org_name
},
}
}

View File

@@ -5,11 +5,13 @@ from collections import namedtuple, defaultdict
from rest_framework import status
from rest_framework.serializers import ValidationError
from rest_framework.response import Response
from rest_framework.decorators import action
from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404, Http404
from django.utils.decorators import method_decorator
from django.db.models.signals import m2m_changed
from common.const.http import POST
from common.exceptions import SomeoneIsDoingThis
from common.const.signals import PRE_REMOVE, POST_REMOVE
from assets.models import Asset
@@ -19,6 +21,8 @@ from common.const.distributed_lock_key import UPDATE_NODE_TREE_LOCK_KEY
from orgs.mixins.api import OrgModelViewSet
from orgs.mixins import generics
from orgs.lock import org_level_transaction_lock
from orgs.utils import current_org
from assets.tasks import check_node_assets_amount_task
from ..hands import IsOrgAdmin
from ..models import Node
from ..tasks import (
@@ -46,6 +50,11 @@ class NodeViewSet(OrgModelViewSet):
permission_classes = (IsOrgAdmin,)
serializer_class = serializers.NodeSerializer
@action(methods=[POST], detail=False, url_name='launch-check-assets-amount-task')
def launch_check_assets_amount_task(self, request):
task = check_node_assets_amount_task.delay(current_org.id)
return Response(data={'task': task.id})
# Only direct creation under the root node is supported; nodes under child nodes must be created via the children API
def perform_create(self, serializer):
child_key = Node.org_root().get_next_child_key()
@@ -61,6 +70,9 @@ class NodeViewSet(OrgModelViewSet):
def destroy(self, request, *args, **kwargs):
node = self.get_object()
if node.is_org_root():
error = _("You can't delete the root node ({})".format(node.value))
return Response(data={'error': error}, status=status.HTTP_403_FORBIDDEN)
if node.has_children_or_has_assets():
error = _("Deletion failed and the node contains children or assets")
return Response(data={'error': error}, status=status.HTTP_403_FORBIDDEN)
@@ -173,7 +185,7 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
return []
assets = self.instance.get_assets().only(
"id", "hostname", "ip", "os",
"org_id", "protocols",
"org_id", "protocols", "is_active"
)
return self.serialize_assets(assets, self.instance.key)
@@ -201,10 +213,8 @@ class NodeAddChildrenApi(generics.UpdateAPIView):
def put(self, request, *args, **kwargs):
instance = self.get_object()
nodes_id = request.data.get("nodes")
children = [get_object_or_none(Node, id=pk) for pk in nodes_id]
children = Node.objects.filter(id__in=nodes_id)
for node in children:
if not node:
continue
node.parent = instance
return Response("OK")

View File

@@ -3,7 +3,8 @@ from django.shortcuts import get_object_or_404
from rest_framework.response import Response
from common.utils import get_logger
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsAppUser
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser
from common.drf.filters import CustomFilter
from orgs.mixins.api import OrgBulkModelViewSet
from orgs.mixins import generics
from orgs.utils import tmp_to_org
@@ -12,7 +13,7 @@ from .. import serializers
from ..serializers import SystemUserWithAuthInfoSerializer
from ..tasks import (
push_system_user_to_assets_manual, test_system_user_connectivity_manual,
push_system_user_a_asset_manual,
push_system_user_to_assets
)
@@ -82,18 +83,18 @@ class SystemUserTaskApi(generics.CreateAPIView):
permission_classes = (IsOrgAdmin,)
serializer_class = serializers.SystemUserTaskSerializer
def do_push(self, system_user, asset=None):
if asset is None:
def do_push(self, system_user, assets_id=None):
if assets_id is None:
task = push_system_user_to_assets_manual.delay(system_user)
else:
username = self.request.query_params.get('username')
task = push_system_user_a_asset_manual.delay(
system_user, asset, username=username
task = push_system_user_to_assets.delay(
system_user.id, assets_id, username=username
)
return task
@staticmethod
def do_test(system_user, asset=None):
def do_test(system_user):
task = test_system_user_connectivity_manual.delay(system_user)
return task
@@ -104,11 +105,16 @@ class SystemUserTaskApi(generics.CreateAPIView):
def perform_create(self, serializer):
action = serializer.validated_data["action"]
asset = serializer.validated_data.get('asset')
assets = serializer.validated_data.get('assets') or []
system_user = self.get_object()
if action == 'push':
task = self.do_push(system_user, asset)
assets = [asset] if asset else assets
assets_id = [asset.id for asset in assets]
assets_id = assets_id if assets_id else None
task = self.do_push(system_user, assets_id)
else:
task = self.do_test(system_user, asset)
task = self.do_test(system_user)
data = getattr(serializer, '_data', {})
data["task"] = task.id
setattr(serializer, '_data', data)

View File

@@ -0,0 +1,72 @@
# Generated by Jiangjie.Bai on 2020-12-01 10:47
from django.db import migrations
from django.db.models import Q
default_node_value = 'Default' # Always
old_default_node_key = '0' # Version <= 1.4.3
new_default_node_key = '1' # Version >= 1.4.4
def compute_parent_key(key):
try:
return key[:key.rindex(':')]
except ValueError:
return ''
def migrate_default_node_key(apps, schema_editor):
""" 将已经存在的Default节点的key从0修改为1 """
# 1.4.3版本中Default节点的key为0
print('')
Node = apps.get_model('assets', 'Node')
Asset = apps.get_model('assets', 'Asset')
# The node with key 0
old_default_node = Node.objects.filter(key=old_default_node_key, value=default_node_value).first()
if not old_default_node:
print(f'Check old default node `key={old_default_node_key} value={default_node_value}` not exists')
return
print(f'Check old default node `key={old_default_node_key} value={default_node_value}` exists')
# The node with key 1
new_default_node = Node.objects.filter(key=new_default_node_key, value=default_node_value).first()
if new_default_node:
print(f'Check new default node `key={new_default_node_key} value={default_node_value}` exists')
all_assets = Asset.objects.filter(
Q(nodes__key__startswith=f'{new_default_node_key}:') | Q(nodes__key=new_default_node_key)
).distinct()
if all_assets:
print(f'Check new default node has assets (count: {len(all_assets)})')
return
all_children = Node.objects.filter(key__startswith=f'{new_default_node_key}:')
if all_children:
print(f'Check new default node has children nodes (count: {len(all_children)})')
return
print(f'Check new default node not has assets and children nodes, delete it.')
new_default_node.delete()
# Perform the modification
print(f'Modify old default node `key` from `{old_default_node_key}` to `{new_default_node_key}`')
nodes = Node.objects.filter(
Q(key__istartswith=f'{old_default_node_key}:') | Q(key=old_default_node_key)
)
for node in nodes:
old_key = node.key
key_list = old_key.split(':', maxsplit=1)
key_list[0] = new_default_node_key
new_key = ':'.join(key_list)
node.key = new_key
node.parent_key = compute_parent_key(node.key)
# Bulk update
print(f'Bulk update nodes `key` and `parent_key`, (count: {len(nodes)})')
Node.objects.bulk_update(nodes, ['key', 'parent_key'])
class Migration(migrations.Migration):
dependencies = [
('assets', '0062_auto_20201117_1938'),
]
operations = [
migrations.RunPython(migrate_default_node_key)
]
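
A small worked example of the key rewrite this migration performs; the node keys below are illustrative.

# Worked example of the key rewrite above (node keys are illustrative).
old_key = '0:3:5'                           # a node under the old Default node (key '0')
key_list = old_key.split(':', maxsplit=1)   # ['0', '3:5']
key_list[0] = '1'                           # new_default_node_key
new_key = ':'.join(key_list)                # '1:3:5'
parent_key = new_key[:new_key.rindex(':')]  # '1:3', same as compute_parent_key(new_key)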

View File

@@ -0,0 +1,17 @@
# Generated by Django 3.1 on 2020-12-03 03:00
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('assets', '0063_migrate_default_node_key'),
]
operations = [
migrations.AlterModelOptions(
name='node',
options={'ordering': ['parent_key', 'value'], 'verbose_name': 'Node'},
),
]

View File

@@ -103,7 +103,7 @@ class FamilyMixin:
if value is None:
value = child_key
child = self.__class__.objects.create(
id=_id, key=child_key, value=value, parent_key=self.key,
id=_id, key=child_key, value=value
)
return child
@@ -354,7 +354,8 @@ class SomeNodesMixin:
def org_root(cls):
root = cls.objects.filter(parent_key='')\
.filter(key__regex=r'^[0-9]+$')\
.exclude(key__startswith='-')
.exclude(key__startswith='-')\
.order_by('key')
if root:
return root[0]
else:
@@ -411,7 +412,7 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
class Meta:
verbose_name = _("Node")
ordering = ['value']
ordering = ['parent_key', 'value']
def __str__(self):
return self.full_value
@@ -490,10 +491,15 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
sort_key_func = lambda n: [int(i) for i in n.key.split(':')]
nodes_sorted = sorted(list(nodes), key=sort_key_func)
nodes_mapper = {n.key: n for n in nodes_sorted}
if not self.is_org_root():
# If this is the org_root, parent_key is '' and the parent is itself, so skip that case
# When updating itself, its own parent_key cannot be obtained from the mapper
nodes_mapper.update({self.parent_key: self.parent})
for node in nodes_sorted:
parent = nodes_mapper.get(node.parent_key)
if not parent:
logger.error(f'Node parent node in mapper: {node.parent_key} {node.value}')
if node.parent_key:
logger.error(f'Node parent node in mapper: {node.parent_key} {node.value}')
continue
node.full_value = parent.full_value + '/' + node.value
self.__class__.objects.bulk_update(nodes, ['full_value'])

View File

@@ -98,9 +98,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
fields_as = list(annotates_fields.keys())
fields = fields_small + fields_fk + fields_m2m + fields_as
read_only_fields = [
'vendor', 'model', 'sn', 'cpu_model', 'cpu_count',
'cpu_cores', 'cpu_vcpus', 'memory', 'disk_total', 'disk_info',
'os', 'os_version', 'os_arch', 'hostname_raw',
'created_by', 'date_created',
] + fields_as

View File

@@ -257,4 +257,8 @@ class SystemUserTaskSerializer(serializers.Serializer):
asset = serializers.PrimaryKeyRelatedField(
queryset=Asset.objects, allow_null=True, required=False, write_only=True
)
assets = serializers.PrimaryKeyRelatedField(
queryset=Asset.objects, allow_null=True, required=False, write_only=True,
many=True
)
task = serializers.CharField(read_only=True)

View File

@@ -4,7 +4,7 @@ from operator import add, sub
from assets.utils import is_asset_exists_in_node
from django.db.models.signals import (
post_save, m2m_changed, pre_delete, post_delete
post_save, m2m_changed, pre_delete, post_delete, pre_save
)
from django.db.models import Q, F
from django.dispatch import receiver
@@ -37,6 +37,11 @@ def test_asset_conn_on_created(asset):
test_asset_connectivity_util.delay([asset])
@receiver(pre_save, sender=Node)
def on_node_pre_save(sender, instance: Node, **kwargs):
instance.parent_key = instance.compute_parent_key()
@receiver(post_save, sender=Asset)
@on_transaction_commit
def on_asset_created_or_update(sender, instance=None, created=False, **kwargs):
@@ -73,6 +78,7 @@ def on_system_user_update(instance: SystemUser, created, **kwargs):
@receiver(m2m_changed, sender=SystemUser.assets.through)
@on_transaction_commit
def on_system_user_assets_change(instance, action, model, pk_set, **kwargs):
"""
When the relation between a system user and assets changes, the system user should be re-pushed to the newly added assets
@@ -91,25 +97,29 @@ def on_system_user_assets_change(instance, action, model, pk_set, **kwargs):
@receiver(m2m_changed, sender=SystemUser.users.through)
def on_system_user_users_change(sender, instance=None, action='', model=None, pk_set=None, **kwargs):
@on_transaction_commit
def on_system_user_users_change(sender, instance: SystemUser, action, model, pk_set, reverse, **kwargs):
"""
When the relation between a system user and users changes, the system user should be re-pushed to its assets
"""
if action != POST_ADD:
return
if reverse:
raise M2MReverseNotAllowed
if not instance.username_same_with_user:
return
logger.debug("System user users change signal recv: {}".format(instance))
queryset = model.objects.filter(pk__in=pk_set)
if model == SystemUser:
system_users = queryset
else:
system_users = [instance]
for s in system_users:
push_system_user_to_assets_manual.delay(s)
usernames = model.objects.filter(pk__in=pk_set).values_list('username', flat=True)
for username in usernames:
push_system_user_to_assets_manual.delay(instance, username)
@receiver(m2m_changed, sender=SystemUser.nodes.through)
@on_transaction_commit
def on_system_user_nodes_change(sender, instance=None, action=None, model=None, pk_set=None, **kwargs):
"""
When the relation between a system user and nodes changes, assets under those nodes should be associated with the new system user

View File

@@ -1,14 +1,27 @@
from celery import shared_task
from django.utils.translation import gettext_lazy as _
from orgs.models import Organization
from orgs.utils import tmp_to_org
from ops.celery.decorator import register_as_period_task
from assets.utils import check_node_assets_amount
from common.utils.lock import AcquireFailed
from common.utils import get_logger
from common.utils.timezone import now
logger = get_logger(__file__)
@shared_task()
def check_node_assets_amount_celery_task():
logger.info(f'>>> {now()} begin check_node_assets_amount_celery_task ...')
check_node_assets_amount()
logger.info(f'>>> {now()} end check_node_assets_amount_celery_task ...')
@shared_task(queue='celery_heavy_tasks')
def check_node_assets_amount_task(org_id=Organization.ROOT_ID):
try:
with tmp_to_org(Organization.get_instance(org_id)):
check_node_assets_amount()
except AcquireFailed:
logger.error(_('The task of self-checking is already running and cannot be started repeatedly'))
@register_as_period_task(crontab='0 2 * * *')
@shared_task(queue='celery_heavy_tasks')
def check_node_assets_amount_period_task():
check_node_assets_amount_task()
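
A minimal usage sketch of triggering the per-organization self-check task manually; the org id is illustrative, and in the NodeViewSet action above the current organization's id is passed instead.

# Minimal usage sketch (the org id is an illustrative value).
from assets.tasks import check_node_assets_amount_task

result = check_node_assets_amount_task.delay('00000000-0000-0000-0000-000000000002')
print(result.id)  # the Celery task id, as returned to the client by the NodeViewSet action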

View File

@@ -2,13 +2,13 @@
from itertools import groupby
from celery import shared_task
from common.db.utils import get_object_if_need, get_objects_if_need, get_objects
from common.db.utils import get_object_if_need, get_objects
from django.utils.translation import ugettext as _
from django.db.models import Empty
from common.utils import encrypt_password, get_logger
from assets.models import SystemUser, Asset
from orgs.utils import org_aware_func
from assets.models import SystemUser, Asset, AuthBook
from orgs.utils import org_aware_func, tmp_to_root_org
from . import const
from .utils import clean_ansible_task_hosts, group_asset_by_platform
@@ -190,15 +190,12 @@ def get_push_system_user_tasks(system_user, platform="unixlike", username=None):
@org_aware_func("system_user")
def push_system_user_util(system_user, assets, task_name, username=None):
from ops.utils import update_or_create_ansible_task
hosts = clean_ansible_task_hosts(assets, system_user=system_user)
if not hosts:
assets = clean_ansible_task_hosts(assets, system_user=system_user)
if not assets:
return {}
platform_hosts_map = {}
hosts_sorted = sorted(hosts, key=group_asset_by_platform)
platform_hosts = groupby(hosts_sorted, key=group_asset_by_platform)
for i in platform_hosts:
platform_hosts_map[i[0]] = list(i[1])
assets_sorted = sorted(assets, key=group_asset_by_platform)
platform_hosts = groupby(assets_sorted, key=group_asset_by_platform)
def run_task(_tasks, _hosts):
if not _tasks:
@@ -209,27 +206,59 @@ def push_system_user_util(system_user, assets, task_name, username=None):
)
task.run()
for platform, _hosts in platform_hosts_map.items():
if not _hosts:
if system_user.username_same_with_user:
if username is None:
# Dynamic system user, but no username specified
usernames = list(system_user.users.all().values_list('username', flat=True).distinct())
else:
usernames = [username]
else:
# Specifying a username is invalid for non-dynamic system users
assert username is None, 'Only Dynamic user can assign `username`'
usernames = [system_user.username]
for platform, _assets in platform_hosts:
_assets = list(_assets)
if not _assets:
continue
print(_("Start push system user for platform: [{}]").format(platform))
print(_("Hosts count: {}").format(len(_hosts)))
print(_("Hosts count: {}").format(len(_assets)))
# If no special auth is configured, there is no need to push to individual machines separately
if not system_user.has_special_auth(username=username):
logger.debug("System user not has special auth")
tasks = get_push_system_user_tasks(system_user, platform, username=username)
run_task(tasks, _hosts)
continue
id_asset_map = {_asset.id: _asset for _asset in _assets}
assets_id = id_asset_map.keys()
no_special_auth = []
special_auth_set = set()
for _host in _hosts:
system_user.load_asset_special_auth(_host, username=username)
tasks = get_push_system_user_tasks(system_user, platform, username=username)
run_task(tasks, [_host])
auth_books = AuthBook.objects.filter(username__in=usernames, asset_id__in=assets_id)
for auth_book in auth_books:
special_auth_set.add((auth_book.username, auth_book.asset_id))
for _username in usernames:
no_special_assets = []
for asset_id in assets_id:
if (_username, asset_id) not in special_auth_set:
no_special_assets.append(id_asset_map[asset_id])
if no_special_assets:
no_special_auth.append((_username, no_special_assets))
for _username, no_special_assets in no_special_auth:
tasks = get_push_system_user_tasks(system_user, platform, username=_username)
run_task(tasks, no_special_assets)
for auth_book in auth_books:
system_user._merge_auth(auth_book)
tasks = get_push_system_user_tasks(system_user, platform, username=auth_book.username)
asset = id_asset_map[auth_book.asset_id]
run_task(tasks, [asset])
@shared_task(queue="ansible")
@tmp_to_root_org()
def push_system_user_to_assets_manual(system_user, username=None):
"""
Push the system user to all assets associated with it
"""
system_user = get_object_if_need(SystemUser, system_user)
assets = system_user.get_related_assets()
task_name = _("Push system users to assets: {}").format(system_user.name)
@@ -237,7 +266,11 @@ def push_system_user_to_assets_manual(system_user, username=None):
@shared_task(queue="ansible")
@tmp_to_root_org()
def push_system_user_a_asset_manual(system_user, asset, username=None):
"""
Push the system user to a single asset
"""
if username is None:
username = system_user.username
task_name = _("Push system users to asset: {}({}) => {}").format(
@@ -247,10 +280,15 @@ def push_system_user_a_asset_manual(system_user, asset, username=None):
@shared_task(queue="ansible")
@tmp_to_root_org()
def push_system_user_to_assets(system_user_id, assets_id, username=None):
"""
Push the system user to the specified assets
"""
system_user = SystemUser.objects.get(id=system_user_id)
assets = get_objects(Asset, assets_id)
task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name, username=username)
# @shared_task

View File

@@ -1,8 +1,11 @@
# ~*~ coding: utf-8 ~*~
#
import time
from django.db.models import Q
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none
from common.utils.lock import DistributedLock
from common.http import is_true
from .models import Asset, Node
@@ -10,17 +13,21 @@ from .models import Asset, Node
logger = get_logger(__file__)
@DistributedLock(name="assets.node.check_node_assets_amount", blocking=False)
def check_node_assets_amount():
for node in Node.objects.all():
logger.info(f'Check node assets amount: {node}')
assets_amount = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
if node.assets_amount != assets_amount:
print(f'>>> <Node:{node.key}> wrong assets amount '
f'{node.assets_amount} right is {assets_amount}')
logger.warn(f'Node wrong assets amount <Node:{node.key}> '
f'{node.assets_amount} right is {assets_amount}')
node.assets_amount = assets_amount
node.save()
# Avoid putting too much pressure on the database during the self-check
time.sleep(0.1)
def is_asset_exists_in_node(asset_pk, node_key):

View File

@@ -0,0 +1,18 @@
# Generated by Django 3.1 on 2020-12-09 03:03
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('audits', '0010_auto_20200811_1122'),
]
operations = [
migrations.AddField(
model_name='userloginlog',
name='backend',
field=models.CharField(default='', max_length=32, verbose_name='Authentication backend'),
),
]

View File

@@ -105,6 +105,7 @@ class UserLoginLog(models.Model):
reason = models.CharField(default='', max_length=128, blank=True, verbose_name=_('Reason'))
status = models.BooleanField(max_length=2, default=True, choices=STATUS_CHOICE, verbose_name=_('Status'))
datetime = models.DateTimeField(default=timezone.now, verbose_name=_('Date login'))
backend = models.CharField(max_length=32, default='', verbose_name=_('Authentication backend'))
@classmethod
def get_login_logs(cls, date_from=None, date_to=None, user=None, keyword=None):

View File

@@ -31,7 +31,8 @@ class UserLoginLogSerializer(serializers.ModelSerializer):
model = models.UserLoginLog
fields = (
'id', 'username', 'type', 'type_display', 'ip', 'city', 'user_agent',
'mfa', 'reason', 'status', 'status_display', 'datetime', 'mfa_display'
'mfa', 'reason', 'status', 'status_display', 'datetime', 'mfa_display',
'backend'
)
extra_kwargs = {
"user_agent": {'label': _('User agent')}

View File

@@ -5,6 +5,8 @@ from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver
from django.db import transaction
from django.utils import timezone
from django.contrib.auth import BACKEND_SESSION_KEY
from django.utils.translation import ugettext_lazy as _
from rest_framework.renderers import JSONRenderer
from rest_framework.request import Request
@@ -32,6 +34,19 @@ MODELS_NEED_RECORD = (
)
LOGIN_BACKEND = {
'PublicKeyAuthBackend': _('SSH Key'),
'RadiusBackend': User.Source.radius.label,
'RadiusRealmBackend': User.Source.radius.label,
'LDAPAuthorizationBackend': User.Source.ldap.label,
'ModelBackend': _('Password'),
'SSOAuthentication': _('SSO'),
'CASBackend': User.Source.cas.label,
'OIDCAuthCodeBackend': User.Source.openid.label,
'OIDCAuthPasswordBackend': User.Source.openid.label,
}
def create_operate_log(action, sender, resource):
user = current_request.user if current_request else None
if not user or not user.is_authenticated:
@@ -109,6 +124,16 @@ def on_audits_log_create(sender, instance=None, **kwargs):
sys_logger.info(msg)
def get_login_backend(request):
backend = request.session.get(BACKEND_SESSION_KEY, '')
backend = backend.rsplit('.', maxsplit=1)[-1]
if backend in LOGIN_BACKEND:
return LOGIN_BACKEND[backend]
else:
logger.warn(f'LOGIN_BACKEND_NOT_FOUND: {backend}')
return ''
def generate_data(username, request):
user_agent = request.META.get('HTTP_USER_AGENT', '')
login_ip = get_request_ip(request) or '0.0.0.0'
@@ -122,7 +147,8 @@ def generate_data(username, request):
'ip': login_ip,
'type': login_type,
'user_agent': user_agent,
'datetime': timezone.now()
'datetime': timezone.now(),
'backend': get_login_backend(request)
}
return data

View File

@@ -6,7 +6,7 @@ import time
from django.core.cache import cache
from django.utils.translation import ugettext as _
from django.utils.six import text_type
from six import text_type
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from rest_framework import HTTP_HEADER_ENCODING

View File

@@ -82,6 +82,12 @@ class LDAPAuthorizationBackend(LDAPBackend):
class LDAPUser(_LDAPUser):
def _search_for_user_dn_from_ldap_util(self):
from settings.utils import LDAPServerUtil
util = LDAPServerUtil()
user_dn = util.search_for_user_dn(self._username)
return user_dn
def _search_for_user_dn(self):
"""
This method was overridden because the AUTH_LDAP_USER_SEARCH
@@ -107,7 +113,14 @@ class LDAPUser(_LDAPUser):
if results is not None and len(results) == 1:
(user_dn, self._user_attrs) = next(iter(results))
else:
user_dn = None
# Work around authentication failures when only the DC domain is configured (the library cannot search the whole tree)
user_dn = self._search_for_user_dn_from_ldap_util()
if user_dn is None:
self._user_dn = None
self._user_attrs = None
else:
self._user_dn = user_dn
self._user_attrs = self._load_user_attrs()
return user_dn

View File

@@ -23,7 +23,7 @@ class CreateUserMixin:
email_suffix = settings.EMAIL_SUFFIX
email = '{}@{}'.format(username, email_suffix)
user = User(username=username, name=username, email=email)
user.source = user.SOURCE_RADIUS
user.source = user.Source.radius.value
user.save()
return user

View File

@@ -218,5 +218,14 @@ class PasswdTooSimple(JMSException):
default_detail = _('Your password is too simple, please change it for security')
def __init__(self, url, *args, **kwargs):
super(PasswdTooSimple, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.url = url
class PasswordRequireResetError(JMSException):
default_code = 'passwd_has_expired'
default_detail = _('Your password has expired, please reset before logging in')
def __init__(self, url, *args, **kwargs):
super().__init__(*args, **kwargs)
self.url = url

View File

@@ -4,7 +4,7 @@
from django import forms
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from captcha.fields import CaptchaField
from captcha.fields import CaptchaField, CaptchaTextInput
class UserLoginForm(forms.Form):
@@ -26,8 +26,12 @@ class UserCheckOtpCodeForm(forms.Form):
otp_code = forms.CharField(label=_('MFA code'), max_length=6)
class CustomCaptchaTextInput(CaptchaTextInput):
template_name = 'authentication/_captcha_field.html'
class CaptchaMixin(forms.Form):
captcha = CaptchaField()
captcha = CaptchaField(widget=CustomCaptchaTextInput)
class ChallengeMixin(forms.Form):

View File

@@ -110,9 +110,8 @@ class AuthMixin:
raise CredentialError(error=errors.reason_user_inactive)
elif not user.is_active:
raise CredentialError(error=errors.reason_user_inactive)
elif user.password_has_expired:
raise CredentialError(error=errors.reason_password_expired)
self._check_password_require_reset_or_not(user)
self._check_passwd_is_too_simple(user, password)
clean_failed_count(username, ip)
@@ -123,20 +122,34 @@ class AuthMixin:
return user
@classmethod
def _check_passwd_is_too_simple(cls, user, password):
def generate_reset_password_url_with_flash_msg(cls, user: User, flash_view_name):
reset_passwd_url = reverse('authentication:reset-password')
query_str = urlencode({
'token': user.generate_reset_token()
})
reset_passwd_url = f'{reset_passwd_url}?{query_str}'
flash_page_url = reverse(flash_view_name)
query_str = urlencode({
'redirect_url': reset_passwd_url
})
return f'{flash_page_url}?{query_str}'
@classmethod
def _check_passwd_is_too_simple(cls, user: User, password):
if user.is_superuser and password == 'admin':
reset_passwd_url = reverse('authentication:reset-password')
query_str = urlencode({
'token': user.generate_reset_token()
})
reset_passwd_url = f'{reset_passwd_url}?{query_str}'
url = cls.generate_reset_password_url_with_flash_msg(
user, 'authentication:passwd-too-simple-flash-msg'
)
raise errors.PasswdTooSimple(url)
flash_page_url = reverse('authentication:passwd-too-simple-flash-msg')
query_str = urlencode({
'redirect_url': reset_passwd_url
})
raise errors.PasswdTooSimple(f'{flash_page_url}?{query_str}')
@classmethod
def _check_password_require_reset_or_not(cls, user: User):
if user.password_has_expired:
url = cls.generate_reset_password_url_with_flash_msg(
user, 'authentication:passwd-has-expired-flash-msg'
)
raise errors.PasswordRequireResetError(url)
def check_user_auth_if_need(self, decrypt_passwd=False):
request = self.request

View File

@@ -1,11 +1,9 @@
import uuid
from functools import partial
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _, ugettext as __
from rest_framework.authtoken.models import Token
from django.conf import settings
from django.utils.crypto import get_random_string
from common.db import models
from common.mixins.models import CommonModelMixin

View File

@@ -1,5 +1,6 @@
from importlib import import_module
from django.contrib.auth import BACKEND_SESSION_KEY
from django.conf import settings
from django.contrib.auth import user_logged_in
from django.core.cache import cache
@@ -24,14 +25,17 @@ def on_user_auth_login_success(sender, user, request, **kwargs):
@receiver(openid_user_login_success)
def on_oidc_user_login_success(sender, request, user, **kwargs):
request.session[BACKEND_SESSION_KEY] = 'OIDCAuthCodeBackend'
post_auth_success.send(sender, user=user, request=request)
@receiver(openid_user_login_failed)
def on_oidc_user_login_failed(sender, username, request, reason, **kwargs):
request.session[BACKEND_SESSION_KEY] = 'OIDCAuthCodeBackend'
post_auth_failed.send(sender, username=username, request=request, reason=reason)
@receiver(cas_user_authenticated)
def on_cas_user_login_success(sender, request, user, **kwargs):
post_auth_success.send(sender, user=user, request=request)
request.session[BACKEND_SESSION_KEY] = 'CASBackend'
post_auth_success.send(sender, user=user, request=request)

View File

@@ -0,0 +1,29 @@
{% load i18n %}
{% spaceless %}
<img src="{{ image }}" alt="captcha" class="captcha" />
<div class="row" style="padding-bottom: 10px">
<div class="col-sm-6">
<div class="input-group-prepend">
{% if audio %}
<a title="{% trans "Play CAPTCHA as audio file" %}" href="{{ audio }}">
{% endif %}
</div>
{% include "django/forms/widgets/multiwidget.html" %}
</div>
</div>
<script>
var placeholder = '{% trans "Captcha" %}'
function refresh_captcha() {
$.getJSON("{% url "captcha-refresh" %}",
function (result) {
$('.captcha').attr('src', result['image_url']);
$('#id_captcha_0').val(result['key'])
})
}
$(document).ready(function () {
$('.captcha').click(refresh_captcha)
$('#id_captcha_1').addClass('form-control').attr('placeholder', placeholder)
})
</script>
{% endspaceless %}

View File

@@ -22,6 +22,7 @@ urlpatterns = [
name='forgot-password-sendmail-success'),
path('password/reset/', users_view.UserResetPasswordView.as_view(), name='reset-password'),
path('password/too-simple-flash-msg/', views.FlashPasswdTooSimpleMsgView.as_view(), name='passwd-too-simple-flash-msg'),
path('password/has-expired-msg/', views.FlashPasswdHasExpiredMsgView.as_view(), name='passwd-has-expired-flash-msg'),
path('password/reset/success/', users_view.UserResetPasswordSuccessView.as_view(), name='reset-password-success'),
path('password/verify/', users_view.UserVerifyPasswordView.as_view(), name='user-verify-password'),

View File

@@ -32,7 +32,7 @@ from ..forms import get_user_login_form_cls
__all__ = [
'UserLoginView', 'UserLogoutView',
'UserLoginGuardView', 'UserLoginWaitConfirmView',
'FlashPasswdTooSimpleMsgView',
'FlashPasswdTooSimpleMsgView', 'FlashPasswdHasExpiredMsgView'
]
@@ -96,7 +96,7 @@ class UserLoginView(mixins.AuthMixin, FormView):
new_form._errors = form.errors
context = self.get_context_data(form=new_form)
return self.render_to_response(context)
except errors.PasswdTooSimple as e:
except (errors.PasswdTooSimple, errors.PasswordRequireResetError) as e:
return redirect(e.url)
self.clear_rsa_key()
return self.redirect_to_guard_view()
@@ -250,3 +250,18 @@ class FlashPasswdTooSimpleMsgView(TemplateView):
'auto_redirect': True,
}
return self.render_to_response(context)
@method_decorator(never_cache, name='dispatch')
class FlashPasswdHasExpiredMsgView(TemplateView):
template_name = 'flash_message_standalone.html'
def get(self, request, *args, **kwargs):
context = {
'title': _('Please change your password'),
'messages': _('Your password has expired, please reset before logging in'),
'interval': 5,
'redirect_url': request.GET.get('redirect_url'),
'auto_redirect': True,
}
return self.render_to_response(context)

View File

@@ -1,11 +1,12 @@
from django.core.exceptions import PermissionDenied, ObjectDoesNotExist as DJObjectDoesNotExist
from django.http import Http404
from django.utils.translation import gettext
from django.db.models.deletion import ProtectedError
from rest_framework import exceptions
from rest_framework.views import set_rollback
from rest_framework.response import Response
from common.exceptions import JMSObjectDoesNotExist
from common.exceptions import JMSObjectDoesNotExist, ReferencedByOthers
from logging import getLogger
logger = getLogger('drf_exception')
@@ -31,6 +32,8 @@ def common_exception_handler(exc, context):
exc = exceptions.PermissionDenied()
elif isinstance(exc, DJObjectDoesNotExist):
exc = JMSObjectDoesNotExist(object_name=extract_object_name(exc, 0))
elif isinstance(exc, ProtectedError):
exc = ReferencedByOthers()
if isinstance(exc, exceptions.APIException):
headers = {}

View File

@@ -1 +1,2 @@
from .csv import *
from .csv import *
from .excel import *

View File

@@ -0,0 +1,132 @@
import abc
import json
import codecs
from django.utils.translation import ugettext_lazy as _
from rest_framework.parsers import BaseParser
from rest_framework import status
from rest_framework.exceptions import ParseError, APIException
from common.utils import get_logger
logger = get_logger(__file__)
class FileContentOverflowedError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_code = 'file_content_overflowed'
default_detail = _('The file content overflowed (The maximum length `{}` bytes)')
class BaseFileParser(BaseParser):
FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10
serializer_cls = None
def check_content_length(self, meta):
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
if content_length > self.FILE_CONTENT_MAX_LENGTH:
msg = FileContentOverflowedError.default_detail.format(self.FILE_CONTENT_MAX_LENGTH)
logger.error(msg)
raise FileContentOverflowedError(msg)
@staticmethod
def get_stream_data(stream):
stream_data = stream.read()
stream_data = stream_data.strip(codecs.BOM_UTF8)
return stream_data
@abc.abstractmethod
def generate_rows(self, stream_data):
raise NotImplemented
def get_column_titles(self, rows):
return next(rows)
def convert_to_field_names(self, column_titles):
fields_map = {}
fields = self.serializer_cls().fields
fields_map.update({v.label: k for k, v in fields.items()})
fields_map.update({k: k for k, _ in fields.items()})
field_names = [
fields_map.get(column_title.strip('*'), '')
for column_title in column_titles
]
return field_names
@staticmethod
def _replace_chinese_quote(s):
trans_table = str.maketrans({
'': '"',
'': '"',
'': '"',
'': '"',
'\'': '"'
})
return s.translate(trans_table)
@classmethod
def process_row(cls, row):
"""
Row processing before building the JSON data
"""
new_row = []
for col in row:
# Convert Chinese quotation marks
col = cls._replace_chinese_quote(col)
# Convert lists/dicts
if isinstance(col, str) and (
(col.startswith('[') and col.endswith(']'))
or
(col.startswith("{") and col.endswith("}"))
):
col = json.loads(col)
new_row.append(col)
return new_row
@staticmethod
def process_row_data(row_data):
"""
Row data processing after building the JSON data
"""
new_row_data = {}
for k, v in row_data.items():
if isinstance(v, list) or isinstance(v, dict) or isinstance(v, str) and k.strip() and v.strip():
new_row_data[k] = v
return new_row_data
def generate_data(self, fields_name, rows):
data = []
for row in rows:
# Skip empty rows
if not any(row):
continue
row = self.process_row(row)
row_data = dict(zip(fields_name, row))
row_data = self.process_row_data(row_data)
data.append(row_data)
return data
def parse(self, stream, media_type=None, parser_context=None):
parser_context = parser_context or {}
try:
view = parser_context['view']
meta = view.request.META
self.serializer_cls = view.get_serializer_class()
except Exception as e:
logger.debug(e, exc_info=True)
raise ParseError('The resource does not support imports!')
self.check_content_length(meta)
try:
stream_data = self.get_stream_data(stream)
rows = self.generate_rows(stream_data)
column_titles = self.get_column_titles(rows)
field_names = self.convert_to_field_names(column_titles)
data = self.generate_data(field_names, rows)
return data
except Exception as e:
logger.error(e, exc_info=True)
raise ParseError('Parse error! ({})'.format(self.media_type))
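
A brief sketch of what process_row() does to one imported row. The import path is an assumption (the class lives in a base module alongside the csv and excel parsers), and the row values are illustrative.

# Sketch only: import path and row values are assumptions for illustration.
from common.drf.parsers.base import BaseFileParser

row = ['web-01', '[“linux”, “prod”]', '{“port”: 22}']
print(BaseFileParser.process_row(row))
# -> ['web-01', ['linux', 'prod'], {'port': 22}]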

View File

@@ -1,32 +1,13 @@
# ~*~ coding: utf-8 ~*~
#
import json
import chardet
import codecs
import unicodecsv
from django.utils.translation import ugettext as _
from rest_framework.parsers import BaseParser
from rest_framework.exceptions import ParseError, APIException
from rest_framework import status
from common.utils import get_logger
logger = get_logger(__file__)
from .base import BaseFileParser
class CsvDataTooBig(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_code = 'csv_data_too_big'
default_detail = _('The max size of CSV is %d bytes')
class JMSCSVParser(BaseParser):
"""
Parses CSV file to serializer data
"""
CSV_UPLOAD_MAX_SIZE = 1024 * 1024 * 10
class CSVFileParser(BaseFileParser):
media_type = 'text/csv'
@@ -38,99 +19,10 @@ class JMSCSVParser(BaseParser):
for line in stream.splitlines():
yield line
@staticmethod
def _gen_rows(csv_data, charset='utf-8', **kwargs):
csv_reader = unicodecsv.reader(csv_data, encoding=charset, **kwargs)
def generate_rows(self, stream_data):
detect_result = chardet.detect(stream_data)
encoding = detect_result.get("encoding", "utf-8")
lines = self._universal_newlines(stream_data)
csv_reader = unicodecsv.reader(lines, encoding=encoding)
for row in csv_reader:
if not any(row):  # skip empty rows
continue
yield row
@staticmethod
def _get_fields_map(serializer_cls):
fields_map = {}
fields = serializer_cls().fields
fields_map.update({v.label: k for k, v in fields.items()})
fields_map.update({k: k for k, _ in fields.items()})
return fields_map
@staticmethod
def _replace_chinese_quot(str_):
trans_table = str.maketrans({
'': '"',
'': '"',
'': '"',
'': '"',
'\'': '"'
})
return str_.translate(trans_table)
@classmethod
def _process_row(cls, row):
"""
Row processing before building the JSON data
"""
_row = []
for col in row:
# Convert lists
if isinstance(col, str) and col.startswith('[') and col.endswith(']'):
col = cls._replace_chinese_quot(col)
col = json.loads(col)
# Convert dicts
if isinstance(col, str) and col.startswith("{") and col.endswith("}"):
col = cls._replace_chinese_quot(col)
col = json.loads(col)
_row.append(col)
return _row
@staticmethod
def _process_row_data(row_data):
"""
构建json数据后的行数据处理
"""
_row_data = {}
for k, v in row_data.items():
if isinstance(v, list) or isinstance(v, dict)\
or isinstance(v, str) and k.strip() and v.strip():
_row_data[k] = v
return _row_data
def parse(self, stream, media_type=None, parser_context=None):
parser_context = parser_context or {}
try:
view = parser_context['view']
meta = view.request.META
serializer_cls = view.get_serializer_class()
except Exception as e:
logger.debug(e, exc_info=True)
raise ParseError('The resource does not support imports!')
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
if content_length > self.CSV_UPLOAD_MAX_SIZE:
msg = CsvDataTooBig.default_detail % self.CSV_UPLOAD_MAX_SIZE
logger.error(msg)
raise CsvDataTooBig(msg)
try:
stream_data = stream.read()
stream_data = stream_data.strip(codecs.BOM_UTF8)
detect_result = chardet.detect(stream_data)
encoding = detect_result.get("encoding", "utf-8")
binary = self._universal_newlines(stream_data)
rows = self._gen_rows(binary, charset=encoding)
header = next(rows)
fields_map = self._get_fields_map(serializer_cls)
header = [fields_map.get(name.strip('*'), '') for name in header]
data = []
for row in rows:
row = self._process_row(row)
row_data = dict(zip(header, row))
row_data = self._process_row_data(row_data)
data.append(row_data)
return data
except Exception as e:
logger.error(e, exc_info=True)
raise ParseError('CSV parse error!')

View File

@@ -0,0 +1,14 @@
import pyexcel
from .base import BaseFileParser
class ExcelFileParser(BaseFileParser):
media_type = 'text/xlsx'
def generate_rows(self, stream_data):
workbook = pyexcel.get_book(file_type='xlsx', file_content=stream_data)
# Get the first worksheet by default
sheet = workbook.sheet_by_index(0)
rows = sheet.rows()
return rows

View File

@@ -1,6 +1,7 @@
from rest_framework import renderers
from .csv import *
from .excel import *
class PassthroughRenderer(renderers.BaseRenderer):

View File

@@ -0,0 +1,134 @@
import abc
from datetime import datetime
from rest_framework.renderers import BaseRenderer
from rest_framework.utils import encoders, json
from common.utils import get_logger
logger = get_logger(__file__)
class BaseFileRenderer(BaseRenderer):
# Render template flag; import, update or export template: ['import', 'update', 'export']
template = 'export'
serializer = None
@staticmethod
def _check_validation_data(data):
detail_key = "detail"
if detail_key in data:
return False
return True
@staticmethod
def _json_format_response(response_data):
return json.dumps(response_data)
def set_response_disposition(self, response):
serializer = self.serializer
if response and hasattr(serializer, 'Meta') and hasattr(serializer.Meta, "model"):
filename_prefix = serializer.Meta.model.__name__.lower()
else:
filename_prefix = 'download'
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename = "{}_{}.{}".format(filename_prefix, now, self.format)
disposition = 'attachment; filename="{}"'.format(filename)
response['Content-Disposition'] = disposition
def get_rendered_fields(self):
fields = self.serializer.fields
if self.template == 'import':
return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
elif self.template == 'update':
return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
else:
return [v for k, v in fields.items() if not v.write_only and k != "org_id"]
@staticmethod
def get_column_titles(render_fields):
return [
'*{}'.format(field.label) if field.required else str(field.label)
for field in render_fields
]
def process_data(self, data):
results = data['results'] if 'results' in data else data
if isinstance(results, dict):
results = [results]
if self.template == 'import':
results = [results[0]] if results else results
else:
# Limit the number of rows
results = results[:10000]
# Converts some UUID fields to strings
results = json.loads(json.dumps(results, cls=encoders.JSONEncoder))
return results
@staticmethod
def generate_rows(data, render_fields):
for item in data:
row = []
for field in render_fields:
value = item.get(field.field_name)
value = str(value) if value else ''
row.append(value)
yield row
@abc.abstractmethod
def initial_writer(self):
raise NotImplementedError
def write_column_titles(self, column_titles):
self.write_row(column_titles)
def write_rows(self, rows):
for row in rows:
self.write_row(row)
@abc.abstractmethod
def write_row(self, row):
raise NotImplementedError
@abc.abstractmethod
def get_rendered_value(self):
raise NotImplementedError
def render(self, data, accepted_media_type=None, renderer_context=None):
if data is None:
return bytes()
if not self._check_validation_data(data):
return self._json_format_response(data)
try:
renderer_context = renderer_context or {}
request = renderer_context['request']
response = renderer_context['response']
view = renderer_context['view']
self.template = request.query_params.get('template', 'export')
self.serializer = view.get_serializer()
self.set_response_disposition(response)
except Exception as e:
logger.debug(e, exc_info=True)
value = 'The resource not support export!'.encode('utf-8')
return value
try:
rendered_fields = self.get_rendered_fields()
column_titles = self.get_column_titles(rendered_fields)
data = self.process_data(data)
rows = self.generate_rows(data, rendered_fields)
self.initial_writer()
self.write_column_titles(column_titles)
self.write_rows(rows)
value = self.get_rendered_value()
except Exception as e:
logger.debug(e, exc_info=True)
value = 'Render error! ({})'.format(self.media_type).encode('utf-8')
return value
return value

View File

@@ -1,83 +1,30 @@
# ~*~ coding: utf-8 ~*~
#
import unicodecsv
import codecs
from datetime import datetime
import unicodecsv
from six import BytesIO
from rest_framework.renderers import BaseRenderer
from rest_framework.utils import encoders, json
from common.utils import get_logger
logger = get_logger(__file__)
from .base import BaseFileRenderer
class JMSCSVRender(BaseRenderer):
class CSVFileRenderer(BaseFileRenderer):
media_type = 'text/csv'
format = 'csv'
@staticmethod
def _get_show_fields(fields, template):
if template == 'import':
return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
elif template == 'update':
return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
else:
return [v for k, v in fields.items() if not v.write_only and k != "org_id"]
writer = None
buffer = None
@staticmethod
def _gen_table(data, fields):
data = data[:10000]
yield ['*{}'.format(f.label) if f.required else f.label for f in fields]
def initial_writer(self):
csv_buffer = BytesIO()
csv_buffer.write(codecs.BOM_UTF8)
csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8')
self.buffer = csv_buffer
self.writer = csv_writer
for item in data:
row = [item.get(f.field_name) for f in fields]
yield row
def set_response_disposition(self, serializer, context):
response = context.get('response')
if response and hasattr(serializer, 'Meta') and \
hasattr(serializer.Meta, "model"):
model_name = serializer.Meta.model.__name__.lower()
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename = "{}_{}.csv".format(model_name, now)
disposition = 'attachment; filename="{}"'.format(filename)
response['Content-Disposition'] = disposition
def render(self, data, media_type=None, renderer_context=None):
renderer_context = renderer_context or {}
request = renderer_context['request']
template = request.query_params.get('template', 'export')
view = renderer_context['view']
if isinstance(data, dict):
data = data.get("results", [])
if template == 'import':
data = [data[0]] if data else data
data = json.loads(json.dumps(data, cls=encoders.JSONEncoder))
try:
serializer = view.get_serializer()
self.set_response_disposition(serializer, renderer_context)
except Exception as e:
logger.debug(e, exc_info=True)
value = 'The resource not support export!'.encode('utf-8')
else:
fields = serializer.fields
show_fields = self._get_show_fields(fields, template)
table = self._gen_table(data, show_fields)
csv_buffer = BytesIO()
csv_buffer.write(codecs.BOM_UTF8)
csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8')
for row in table:
csv_writer.writerow(row)
value = csv_buffer.getvalue()
def write_row(self, row):
self.writer.writerow(row)
def get_rendered_value(self):
value = self.buffer.getvalue()
return value

View File

@@ -0,0 +1,31 @@
from openpyxl import Workbook
from openpyxl.writer.excel import save_virtual_workbook
from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
from .base import BaseFileRenderer
class ExcelFileRenderer(BaseFileRenderer):
media_type = "application/xlsx"
format = "xlsx"
wb = None
ws = None
row_count = 0
def initial_writer(self):
self.wb = Workbook()
self.ws = self.wb.active
def write_row(self, row):
self.row_count += 1
column_count = 0
for cell_value in row:
# Strip illegal characters
column_count += 1
cell_value = ILLEGAL_CHARACTERS_RE.sub(r'', cell_value)
self.ws.cell(row=self.row_count, column=column_count, value=cell_value)
def get_rendered_value(self):
value = save_virtual_workbook(self.wb)
return value

View File

@@ -33,3 +33,9 @@ class Timeout(JMSException):
class M2MReverseNotAllowed(JMSException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('M2M reverse not allowed')
class ReferencedByOthers(JMSException):
status_code = status.HTTP_400_BAD_REQUEST
default_code = 'referenced_by_others'
default_detail = _('Is referenced by other objects and cannot be deleted')

View File

@@ -3,7 +3,7 @@
import json
from django import forms
from django.utils import six
import six
from django.core.exceptions import ValidationError
from django.utils.translation import ugettext as _
from ..utils import signer

View File

@@ -31,7 +31,7 @@ class JsonMixin:
def json_encode(data):
return json.dumps(data)
def from_db_value(self, value, expression, connection, context):
def from_db_value(self, value, expression, connection, context=None):
if value is None:
return value
return self.json_decode(value)
@@ -54,7 +54,7 @@ class JsonMixin:
class JsonTypeMixin(JsonMixin):
tp = dict
def from_db_value(self, value, expression, connection, context):
def from_db_value(self, value, expression, connection, context=None):
value = super().from_db_value(value, expression, connection, context)
if not isinstance(value, self.tp):
value = self.tp()
@@ -116,7 +116,7 @@ class EncryptMixin:
def decrypt_from_signer(self, value):
return signer.unsign(value) or ''
def from_db_value(self, value, expression, connection, context):
def from_db_value(self, value, expression, connection, context=None):
if value is None:
return value
value = force_text(value)

View File

@@ -2,7 +2,7 @@
#
from rest_framework import serializers
from django.utils import six
import six
__all__ = [

View File

@@ -41,7 +41,7 @@ def timesince(dt, since='', default="just now"):
3 days, 5 hours.
"""
if since is '':
if not since:
since = datetime.datetime.utcnow()
if since is None:

View File

@@ -0,0 +1,10 @@
import inspect
def copy_function_args(func, locals_dict: dict):
signature = inspect.signature(func)
keys = signature.parameters.keys()
kwargs = {}
for k in keys:
kwargs[k] = locals_dict.get(k)
return kwargs
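Not part of the diff: a small sketch of what copy_function_args does. It snapshots the caller's local variables that match a function's signature, which DistributedLock below uses to re-create an equivalent lock for every decorated call (the names in the example are illustrative only).

import inspect

def copy_function_args(func, locals_dict: dict):
    signature = inspect.signature(func)
    return {k: locals_dict.get(k) for k in signature.parameters.keys()}

def connect(host, port=6379, password=None):
    # Capture the arguments this call was made with, keyed by parameter name
    return copy_function_args(connect, locals())

print(connect('127.0.0.1', password='secret'))
# {'host': '127.0.0.1', 'port': 6379, 'password': 'secret'}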

apps/common/utils/lock.py (new file, 55 lines)
View File

@@ -0,0 +1,55 @@
from functools import wraps
from redis_lock import Lock as RedisLock
from redis import Redis
from common.utils import get_logger
from common.utils.inspect import copy_function_args
from apps.jumpserver.const import CONFIG
logger = get_logger(__file__)
class AcquireFailed(RuntimeError):
pass
class DistributedLock(RedisLock):
def __init__(self, name, blocking=True, expire=60*2, auto_renewal=True):
"""
A distributed lock built on redis
:param name:
Name of the lock; must be globally unique
:param blocking:
Only takes effect when the lock is used as a decorator or in a `with` block.
:param expire:
Expiration time of the lock. Note the lock is not necessarily released at that moment; there are two cases:
when `auto_renewal=False`, the lock is released
when `auto_renewal=True`, if the program has not released the lock before it expires, its lifetime is extended.
This guards against a deadlock when the program terminates unexpectedly without releasing the lock.
"""
self.kwargs_copy = copy_function_args(self.__init__, locals())
redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal)
self._blocking = blocking
def __enter__(self):
acquired = self.acquire(blocking=self._blocking)
if self._blocking and not acquired:
raise EnvironmentError("Lock wasn't acquired, but blocking=True")
if not acquired:
raise AcquireFailed
return self
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
self.release()
def __call__(self, func):
@wraps(func)
def inner(*args, **kwds):
# A new lock object must be created for each call
with self.__class__(**self.kwargs_copy):
return func(*args, **kwds)
return inner
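Not part of the diff: a brief usage sketch of DistributedLock under the behavior shown above. blocking only matters when the lock is used as a context manager or decorator, and the decorator builds a fresh lock object from the copied constructor arguments on every call (lock names and the protected functions below are illustrative).

from common.utils.lock import DistributedLock, AcquireFailed

# As a context manager: waits for the lock (blocking=True is the default)
with DistributedLock(name='demo-rebuild-user-tree'):
    do_protected_work()  # placeholder for the critical section

# As a decorator: a new lock object is created for each call
@DistributedLock(name='demo-periodic-check', blocking=False)
def periodic_check():
    ...  # only one worker at a time runs this

try:
    periodic_check()
except AcquireFailed:
    pass  # another worker holds the lock; skip this round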

View File

@@ -14,9 +14,9 @@ alphanumeric = RegexValidator(r'^[0-9a-zA-Z_@\-\.]*$', _('Special char not allow
class ProjectUniqueValidator(UniqueTogetherValidator):
def __call__(self, attrs):
def __call__(self, attrs, serializer):
try:
super().__call__(attrs)
super().__call__(attrs, serializer)
except ValidationError as e:
errors = {}
for field in self.fields:

View File

@@ -2,13 +2,14 @@ from django.core.cache import cache
from django.utils import timezone
from django.utils.timesince import timesince
from django.db.models import Count, Max
from django.http.response import JsonResponse
from django.http.response import JsonResponse, HttpResponse
from rest_framework.views import APIView
from collections import Counter
from users.models import User
from assets.models import Asset
from terminal.models import Session
from terminal.utils import ComponentsPrometheusMetricsUtil
from orgs.utils import current_org
from common.permissions import IsOrgAdmin, IsOrgAuditor
from common.utils import lazyproperty
@@ -305,3 +306,11 @@ class IndexApi(TotalCountMixin, DatesLoginMetricMixin, APIView):
return JsonResponse(data, status=200)
class PrometheusMetricsApi(APIView):
permission_classes = ()
def get(self, request, *args, **kwargs):
util = ComponentsPrometheusMetricsUtil()
metrics_text = util.get_prometheus_metrics_text()
return HttpResponse(metrics_text, content_type='text/plain; version=0.0.4; charset=utf-8')
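Not part of the diff: a quick sketch of consuming the new metrics endpoint. permission_classes is empty, so no authentication is required, and the route added further down in jumpserver/urls.py exposes it under /api/v1/prometheus/metrics/ (the host below is a placeholder).

import requests

resp = requests.get('https://jumpserver.example.com/api/v1/prometheus/metrics/')
print(resp.headers.get('Content-Type'))  # text/plain; version=0.0.4; charset=utf-8
print(resp.text)  # Prometheus exposition-format text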

View File

@@ -16,7 +16,7 @@ import json
import yaml
from importlib import import_module
from django.urls import reverse_lazy
from django.contrib.staticfiles.templatetags.staticfiles import static
from django.templatetags.static import static
from urllib.parse import urljoin, urlparse
from django.utils.translation import ugettext_lazy as _

View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
from django.contrib.staticfiles.templatetags.staticfiles import static
from django.templatetags.static import static
from django.conf import settings
from django.utils.translation import gettext_lazy as _

View File

@@ -64,6 +64,7 @@ INSTALLED_APPS = [
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'django.forms',
]

View File

@@ -11,14 +11,17 @@ REST_FRAMEWORK = {
),
'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
'rest_framework.renderers.BrowsableAPIRenderer',
'common.drf.renders.JMSCSVRender',
# 'rest_framework.renderers.BrowsableAPIRenderer',
'common.drf.renders.CSVFileRenderer',
'common.drf.renders.ExcelFileRenderer',
),
'DEFAULT_PARSER_CLASSES': (
'rest_framework.parsers.JSONParser',
'rest_framework.parsers.FormParser',
'rest_framework.parsers.MultiPartParser',
'common.drf.parsers.JMSCSVParser',
'common.drf.parsers.CSVFileParser',
'common.drf.parsers.ExcelFileParser',
'rest_framework.parsers.FileUploadParser',
),
'DEFAULT_AUTHENTICATION_CLASSES': (
@@ -61,10 +64,10 @@ SWAGGER_SETTINGS = {
# Captcha settings, more see https://django-simple-captcha.readthedocs.io/en/latest/advanced.html
CAPTCHA_IMAGE_SIZE = (80, 33)
CAPTCHA_IMAGE_SIZE = (140, 34)
CAPTCHA_FOREGROUND_COLOR = '#001100'
CAPTCHA_NOISE_FUNCTIONS = ('captcha.helpers.noise_dots',)
CAPTCHA_TEST_MODE = CONFIG.CAPTCHA_TEST_MODE
CAPTCHA_CHALLENGE_FUNCT = 'captcha.helpers.math_challenge'
# Django bootstrap3 setting, more see http://django-bootstrap3.readthedocs.io/en/latest/settings.html
BOOTSTRAP3 = {

View File

@@ -5,6 +5,7 @@ from django.urls import path, include, re_path
from django.conf import settings
from django.conf.urls.static import static
from django.conf.urls.i18n import i18n_patterns
from django.http import HttpResponse
from django.views.i18n import JavaScriptCatalog
from . import views, api
@@ -23,6 +24,7 @@ api_v1 = [
path('common/', include('common.urls.api_urls', namespace='api-common')),
path('applications/', include('applications.urls.api_urls', namespace='api-applications')),
path('tickets/', include('tickets.urls.api_urls', namespace='api-tickets')),
path('prometheus/metrics/', api.PrometheusMetricsApi.as_view())
]
api_v2 = [
@@ -30,18 +32,13 @@ api_v2 = [
path('users/', include('users.urls.api_urls_v2', namespace='api-users-v2')),
]
app_view_patterns = [
path('auth/', include('authentication.urls.view_urls'), name='auth'),
path('ops/', include('ops.urls.view_urls'), name='ops'),
re_path(r'flower/(?P<path>.*)', views.celery_flower_view, name='flower-view'),
]
if settings.XPACK_ENABLED:
app_view_patterns.append(
path('xpack/', include('xpack.urls.view_urls', namespace='xpack'))
)
api_v1.append(
path('xpack/', include('xpack.urls.api_urls', namespace='api-xpack'))
)
@@ -53,7 +50,6 @@ apps = [
'flower', 'luna', 'koko', 'ws', 'docs', 'redocs',
]
urlpatterns = [
path('', views.IndexView.as_view(), name='index'),
path('api/v1/', include(api_v1)),
@@ -63,32 +59,30 @@ urlpatterns = [
# External apps url
path('core/auth/captcha/', include('captcha.urls')),
path('core/', include(app_view_patterns)),
path('ui/', views.UIView.as_view())
path('ui/', views.UIView.as_view()),
]
# Static file routes
urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) \
+ static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)
js_i18n_patterns = [
# js i18n routes
urlpatterns += [
path('core/jsi18n/', JavaScriptCatalog.as_view(), name='javascript-catalog'),
]
urlpatterns += js_i18n_patterns
handler404 = 'jumpserver.views.handler404'
handler500 = 'jumpserver.views.handler500'
# docs routes
urlpatterns += [
re_path('^api/swagger(?P<format>\.json|\.yaml)$',
views.get_swagger_view().without_ui(cache_timeout=1), name='schema-json'),
re_path('api/docs/?', views.get_swagger_view().with_ui('swagger', cache_timeout=1), name="docs"),
re_path('api/redoc/?', views.get_swagger_view().with_ui('redoc', cache_timeout=1), name='redoc'),
if settings.DEBUG:
urlpatterns += [
re_path('^api/swagger(?P<format>\.json|\.yaml)$',
views.get_swagger_view().without_ui(cache_timeout=1), name='schema-json'),
re_path('api/docs/?', views.get_swagger_view().with_ui('swagger', cache_timeout=1), name="docs"),
re_path('api/redoc/?', views.get_swagger_view().with_ui('redoc', cache_timeout=1), name='redoc'),
re_path('^api/v2/swagger(?P<format>\.json|\.yaml)$',
views.get_swagger_view().without_ui(cache_timeout=1), name='schema-json'),
path('api/docs/v2/', views.get_swagger_view("v2").with_ui('swagger', cache_timeout=1), name="docs"),
path('api/redoc/v2/', views.get_swagger_view("v2").with_ui('redoc', cache_timeout=1), name='redoc'),
]
re_path('^api/v2/swagger(?P<format>\.json|\.yaml)$',
views.get_swagger_view().without_ui(cache_timeout=1), name='schema-json'),
path('api/docs/v2/', views.get_swagger_view("v2").with_ui('swagger', cache_timeout=1), name="docs"),
path('api/redoc/v2/', views.get_swagger_view("v2").with_ui('redoc', cache_timeout=1), name='redoc'),
]
# Compatibility with the old URLs
@@ -97,3 +91,5 @@ old_app_pattern = r'^{}'.format(old_app_pattern)
urlpatterns += [re_path(old_app_pattern, views.redirect_old_apps_view)]
handler404 = 'jumpserver.views.handler404'
handler500 = 'jumpserver.views.handler500'

View File

@@ -76,9 +76,8 @@ def get_swagger_view(version='v1'):
patterns = api_v1_patterns
schema_view = get_schema_view(
api_info,
public=True,
patterns=patterns,
permission_classes=(permissions.AllowAny,),
permission_classes=(permissions.IsAuthenticated,),
)
return schema_view

Binary file not shown.

File diff suppressed because it is too large

View File

@@ -29,16 +29,3 @@ configs["CELERY_ROUTES"] = {
app.namespace = 'CELERY'
app.conf.update(configs)
app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS])
app.conf.beat_schedule = {
'check-asset-permission-expired': {
'task': 'perms.tasks.check_asset_permission_expired',
'schedule': settings.PERM_EXPIRED_CHECK_PERIODIC,
'args': ()
},
'check-node-assets-amount': {
'task': 'assets.tasks.nodes_amount.check_node_assets_amount_celery_task',
'schedule': crontab(minute=0, hour=0),
'args': ()
},
}

View File

@@ -3,6 +3,8 @@
import json
import os
import redis_lock
import redis
from django.conf import settings
from django.utils.timezone import get_current_timezone
from django.db.utils import ProgrammingError, OperationalError
@@ -105,3 +107,27 @@ def get_celery_task_log_path(task_id):
path = os.path.join(settings.CELERY_LOG_DIR, rel_path)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
def get_celery_status():
from . import app
i = app.control.inspect()
ping_data = i.ping() or {}
active_nodes = [k for k, v in ping_data.items() if v.get('ok') == 'pong']
active_queue_worker = set([n.split('@')[0] for n in active_nodes if n])
if len(active_queue_worker) < 5:
print("Not all celery worker worked")
return False
else:
return True
def get_beat_status():
CONFIG = settings.CONFIG
r = redis.Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
lock = redis_lock.Lock(r, name="beat-distribute-start-lock")
try:
locked = lock.locked()
return locked
except redis.ConnectionError:
return False

View File

@@ -0,0 +1,20 @@
from django.core.management.base import BaseCommand, CommandError
class Command(BaseCommand):
help = 'Ops manage commands'
def add_arguments(self, parser):
parser.add_argument('check_celery', nargs='?', help='Check celery health')
def handle(self, *args, **options):
from ops.celery.utils import get_celery_status, get_beat_status
ok = get_celery_status()
if not ok:
raise CommandError('Celery worker unhealthy')
ok = get_beat_status()
if not ok:
raise CommandError('Beat unhealthy')

View File

@@ -180,8 +180,10 @@ class AdHoc(OrgModelMixin):
def run(self):
try:
hid = current_task.request.id
if AdHocExecution.objects.filter(id=hid).exists():
hid = uuid.uuid4()
except AttributeError:
hid = str(uuid.uuid4())
hid = uuid.uuid4()
execution = AdHocExecution(
id=hid, adhoc=self, task=self.task,
task_display=str(self.task)[:128],

View File

@@ -128,6 +128,7 @@ class CommandExecutionSerializer(serializers.ModelSerializer):
'result', 'is_finished', 'log_url', 'date_created',
'date_finished'
]
ref_name = 'OpsCommandExecution'
@staticmethod
def get_log_url(obj):

View File

@@ -3,7 +3,7 @@ from django.utils.translation import ugettext_lazy as _
from common.utils import get_logger, get_object_or_none
from common.tasks import send_mail_async
from orgs.utils import tmp_to_org, org_aware_func
from orgs.utils import org_aware_func
from .models import Task, AdHoc

View File

@@ -92,6 +92,7 @@ class OrgMemberAdminRelationBulkViewSet(JMSBulkRelationModelViewSet):
serializer_class = OrgMemberAdminSerializer
filterset_class = OrgMemberRelationFilterSet
search_fields = ('user__name', 'user__username', 'org__name')
lookup_field = 'user_id'
def get_queryset(self):
queryset = super().get_queryset()
@@ -116,6 +117,7 @@ class OrgMemberUserRelationBulkViewSet(JMSBulkRelationModelViewSet):
serializer_class = OrgMemberUserSerializer
filterset_class = OrgMemberRelationFilterSet
search_fields = ('user__name', 'user__username', 'org__name')
lookup_field = 'user_id'
def get_queryset(self):
queryset = super().get_queryset()

View File

@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
#
from .models import Organization
from .utils import get_org_from_request, set_current_org

View File

@@ -8,7 +8,7 @@ from django.core.exceptions import ValidationError
from common.utils import get_logger
from ..utils import (
set_current_org, get_current_org, current_org,
filter_org_queryset
filter_org_queryset, get_org_by_id, get_org_name_by_id
)
from ..models import Organization
@@ -70,13 +70,11 @@ class OrgModelMixin(models.Model):
@property
def org(self):
from orgs.models import Organization
org = Organization.get_instance(self.org_id)
return org
return get_org_by_id(self.org_id)
@property
def org_name(self):
return self.org.name
return get_org_name_by_id(self.org_id)
@property
def fullname(self, attr=None):

View File

@@ -74,26 +74,29 @@ class OrgMemberSerializer(BulkModelSerializer):
).distinct()
class OrgMemberAdminSerializer(BulkModelSerializer):
class OrgMemberOldBaseSerializer(BulkModelSerializer):
organization = serializers.PrimaryKeyRelatedField(
label=_('Organization'), queryset=Organization.objects.all(), required=True, source='org'
)
def to_internal_value(self, data):
view = self.context['view']
org_id = view.kwargs.get('org_id')
if org_id:
data['organization'] = org_id
return super().to_internal_value(data)
class Meta:
model = OrganizationMember
fields = ('id', 'organization', 'user', 'role')
class OrgMemberAdminSerializer(OrgMemberOldBaseSerializer):
role = serializers.HiddenField(default=ROLE.ADMIN)
organization = serializers.PrimaryKeyRelatedField(
label=_('Organization'), queryset=Organization.objects.all(), required=True, source='org'
)
class Meta:
model = OrganizationMember
fields = ('id', 'organization', 'user', 'role')
class OrgMemberUserSerializer(BulkModelSerializer):
class OrgMemberUserSerializer(OrgMemberOldBaseSerializer):
role = serializers.HiddenField(default=ROLE.USER)
organization = serializers.PrimaryKeyRelatedField(
label=_('Organization'), queryset=Organization.objects.all(), required=True, source='org'
)
class Meta:
model = OrganizationMember
fields = ('id', 'organization', 'user', 'role')
class OrgRetrieveSerializer(OrgReadSerializer):

View File

@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
#
from collections import defaultdict
from functools import partial
from django.db.models.signals import m2m_changed
from django.db.models.signals import post_save
@@ -7,10 +9,10 @@ from django.dispatch import receiver
from orgs.utils import tmp_to_org
from .models import Organization, OrganizationMember
from .hands import set_current_org, current_org, Node, get_current_org
from perms.models import (AssetPermission, DatabaseAppPermission,
K8sAppPermission, RemoteAppPermission)
from users.models import UserGroup
from .hands import set_current_org, Node, get_current_org
from perms.models import (AssetPermission, ApplicationPermission)
from users.models import UserGroup, User
from common.const.signals import PRE_REMOVE, POST_REMOVE
@receiver(post_save, sender=Organization)
@@ -34,11 +36,44 @@ def _remove_users(model, users, org):
users = (users, )
m2m_model = model.users.through
if model.users.reverse:
reverse = model.users.reverse
if reverse:
m2m_field_name = model.users.field.m2m_reverse_field_name()
else:
m2m_field_name = model.users.field.m2m_field_name()
m2m_model.objects.filter(**{'user__in': users, f'{m2m_field_name}__org_id': org.id}).delete()
relations = m2m_model.objects.filter(**{
'user__in': users,
f'{m2m_field_name}__org_id': org.id
})
object_id_users_id_map = defaultdict(set)
m2m_field_attr_name = f'{m2m_field_name}_id'
for relation in relations:
object_id = getattr(relation, m2m_field_attr_name)
object_id_users_id_map[object_id].add(relation.user_id)
objects = model.objects.filter(id__in=object_id_users_id_map.keys())
send_m2m_change_signal = partial(
m2m_changed.send,
sender=m2m_model, reverse=reverse, model=User, using=model.objects.db
)
for obj in objects:
send_m2m_change_signal(
instance=obj,
pk_set=object_id_users_id_map[obj.id],
action=PRE_REMOVE
)
relations.delete()
for obj in objects:
send_m2m_change_signal(
instance=obj,
pk_set=object_id_users_id_map[obj.id],
action=POST_REMOVE
)
def _clear_users_from_org(org, users):
@@ -48,8 +83,7 @@ def _clear_users_from_org(org, users):
if not users:
return
models = (AssetPermission, DatabaseAppPermission,
RemoteAppPermission, K8sAppPermission, UserGroup)
models = (AssetPermission, ApplicationPermission, UserGroup)
for m in models:
_remove_users(m, users, org)

View File

@@ -65,6 +65,47 @@ def get_current_org_id():
return org_id
def construct_org_mapper():
orgs = Organization.objects.all()
org_mapper = {str(org.id): org for org in orgs}
default_org = Organization.default()
org_mapper.update({
'': default_org,
Organization.DEFAULT_ID: default_org,
Organization.ROOT_ID: Organization.root(),
Organization.SYSTEM_ID: Organization.system()
})
return org_mapper
def set_org_mapper(org_mapper):
setattr(thread_local, 'org_mapper', org_mapper)
def get_org_mapper():
org_mapper = _find('org_mapper')
if org_mapper is None:
org_mapper = construct_org_mapper()
set_org_mapper(org_mapper)
return org_mapper
def get_org_by_id(org_id):
org_id = str(org_id)
org_mapper = get_org_mapper()
org = org_mapper.get(org_id)
return org
def get_org_name_by_id(org_id):
org = get_org_by_id(org_id)
if org:
org_name = org.name
else:
org_name = 'Not Found'
return org_name
def get_current_org_id_for_serializer():
org_id = get_current_org_id()
if org_id == Organization.DEFAULT_ID:

View File

@@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
#
from common.permissions import IsOrgAdmin
from orgs.mixins.api import OrgBulkModelViewSet
from applications.models import Application
from perms.models import ApplicationPermission
from perms import serializers
from ..base import BasePermissionViewSet
class ApplicationPermissionViewSet(OrgBulkModelViewSet):
class ApplicationPermissionViewSet(BasePermissionViewSet):
"""
CRUD API for the application permission list
"""
@@ -14,7 +14,9 @@ class ApplicationPermissionViewSet(OrgBulkModelViewSet):
serializer_class = serializers.ApplicationPermissionSerializer
filter_fields = ['name', 'category', 'type']
search_fields = filter_fields
permission_classes = (IsOrgAdmin,)
custom_filter_fields = BasePermissionViewSet.custom_filter_fields + [
'application_id', 'application'
]
def get_queryset(self):
queryset = super().get_queryset().prefetch_related(
@@ -22,3 +24,22 @@ class ApplicationPermissionViewSet(OrgBulkModelViewSet):
)
return queryset
def filter_application(self, queryset):
application_id = self.request.query_params.get('application_id')
application_name = self.request.query_params.get('application')
if application_id:
applications = Application.objects.filter(pk=application_id)
elif application_name:
applications = Application.objects.filter(name=application_name)
else:
return queryset
if not applications:
return queryset.none()
queryset = queryset.filter(applications=applications)
return queryset
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
queryset = self.filter_application(queryset)
return queryset

View File

@@ -48,7 +48,7 @@ class ApplicationsAsTreeMixin(SerializeApplicationToTreeNodeMixin):
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
data = self.serialize_applications(queryset)
data = self.serialize_applications_with_org(queryset)
return Response(data=data)

View File

@@ -2,14 +2,12 @@
#
from django.db.models import Q
from common.permissions import IsOrgAdmin
from orgs.mixins.api import OrgBulkModelViewSet
from common.utils import get_object_or_none
from perms.models import AssetPermission
from perms.hands import (
User, UserGroup, Asset, Node, SystemUser,
Asset, Node
)
from perms import serializers
from ..base import BasePermissionViewSet
__all__ = [
@@ -17,14 +15,16 @@ __all__ = [
]
class AssetPermissionViewSet(OrgBulkModelViewSet):
class AssetPermissionViewSet(BasePermissionViewSet):
"""
CRUD API for the asset permission list
"""
model = AssetPermission
serializer_class = serializers.AssetPermissionSerializer
filter_fields = ['name']
permission_classes = (IsOrgAdmin,)
custom_filter_fields = BasePermissionViewSet.custom_filter_fields + [
'node_id', 'node', 'asset_id', 'hostname', 'ip'
]
def get_queryset(self):
queryset = super().get_queryset().prefetch_related(
@@ -32,35 +32,6 @@ class AssetPermissionViewSet(OrgBulkModelViewSet):
)
return queryset
def is_query_all(self):
query_all = self.request.query_params.get('all', '1') == '1'
return query_all
def filter_valid(self, queryset):
valid_query = self.request.query_params.get('is_valid', None)
if valid_query is None:
return queryset
invalid = valid_query in ['0', 'N', 'false', 'False']
if invalid:
queryset = queryset.invalid()
else:
queryset = queryset.valid()
return queryset
def filter_system_user(self, queryset):
system_user_id = self.request.query_params.get('system_user_id')
system_user_name = self.request.query_params.get('system_user')
if system_user_id:
system_user = get_object_or_none(SystemUser, pk=system_user_id)
elif system_user_name:
system_user = get_object_or_none(SystemUser, name=system_user_name)
else:
return queryset
if not system_user:
return queryset.none()
queryset = queryset.filter(system_users=system_user)
return queryset
def filter_node(self, queryset):
node_id = self.request.query_params.get('node_id')
node_name = self.request.query_params.get('node')
@@ -112,55 +83,8 @@ class AssetPermissionViewSet(OrgBulkModelViewSet):
).distinct()
return queryset
def filter_user(self, queryset):
user_id = self.request.query_params.get('user_id')
username = self.request.query_params.get('username')
if user_id:
user = get_object_or_none(User, pk=user_id)
elif username:
user = get_object_or_none(User, username=username)
else:
return queryset
if not user:
return queryset.none()
if not self.is_query_all():
queryset = queryset.filter(users=user)
return queryset
groups = user.groups.all()
queryset = queryset.filter(
Q(users=user) | Q(user_groups__in=groups)
).distinct()
return queryset
def filter_user_group(self, queryset):
user_group_id = self.request.query_params.get('user_group_id')
user_group_name = self.request.query_params.get('user_group')
if user_group_id:
group = get_object_or_none(UserGroup, pk=user_group_id)
elif user_group_name:
group = get_object_or_none(UserGroup, name=user_group_name)
else:
return queryset
if not group:
return queryset.none()
queryset = queryset.filter(user_groups=group)
return queryset
def filter_keyword(self, queryset):
keyword = self.request.query_params.get('search')
if not keyword:
return queryset
queryset = queryset.filter(name__icontains=keyword)
return queryset
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
queryset = self.filter_valid(queryset)
queryset = self.filter_user(queryset)
queryset = self.filter_keyword(queryset)
queryset = self.filter_asset(queryset)
queryset = self.filter_node(queryset)
queryset = self.filter_system_user(queryset)
queryset = self.filter_user_group(queryset)
queryset = queryset.distinct()
return queryset

View File

@@ -32,9 +32,6 @@ class UserGroupMixin:
class UserGroupGrantedAssetsApi(ListAPIView):
"""
Get assets directly granted to a user group
"""
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.AssetGrantedSerializer
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
@@ -44,11 +41,27 @@ class UserGroupGrantedAssetsApi(ListAPIView):
def get_queryset(self):
user_group_id = self.kwargs.get('pk', '')
return Asset.objects.filter(
Q(granted_by_permissions__user_groups__id=user_group_id)
asset_perms_id = list(AssetPermission.objects.valid().filter(
user_groups__id=user_group_id
).distinct().values_list('id', flat=True))
granted_node_keys = Node.objects.filter(
granted_by_permissions__id__in=asset_perms_id,
).distinct().values_list('key', flat=True)
granted_q = Q()
for _key in granted_node_keys:
granted_q |= Q(nodes__key__startswith=f'{_key}:')
granted_q |= Q(nodes__key=_key)
granted_q |= Q(granted_by_permissions__id__in=asset_perms_id)
assets = Asset.objects.filter(
granted_q
).distinct().only(
*self.only_fields
)
return assets
class UserGroupGrantedNodeAssetsApi(ListAPIView):
@@ -59,6 +72,8 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
search_fields = ['hostname', 'ip', 'comment']
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
user_group_id = self.kwargs.get('pk', '')
node_id = self.kwargs.get("node_id")
node = Node.objects.get(id=node_id)
@@ -66,7 +81,7 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
granted = AssetPermission.objects.filter(
user_groups__id=user_group_id,
nodes__id=node_id
).exists()
).valid().exists()
if granted:
assets = Asset.objects.filter(
Q(nodes__key__startswith=f'{node.key}:') |
@@ -74,8 +89,12 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
)
return assets
else:
asset_perms_id = list(AssetPermission.objects.valid().filter(
user_groups__id=user_group_id
).distinct().values_list('id', flat=True))
granted_node_keys = Node.objects.filter(
granted_by_permissions__user_groups__id=user_group_id,
granted_by_permissions__id__in=asset_perms_id,
key__startswith=f'{node.key}:'
).distinct().values_list('key', flat=True)
@@ -85,7 +104,7 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
granted_node_q |= Q(nodes__key=_key)
granted_asset_q = (
Q(granted_by_permissions__user_groups__id=user_group_id) &
Q(granted_by_permissions__id__in=asset_perms_id) &
(
Q(nodes__key__startswith=f'{node.key}:') |
Q(nodes__key=node.key)
@@ -129,12 +148,16 @@ class UserGroupGrantedNodeChildrenAsTreeApi(SerializeToTreeNodeMixin, ListAPIVie
group_id = self.kwargs.get('pk')
node_key = self.request.query_params.get('key', None)
asset_perms_id = list(AssetPermission.objects.valid().filter(
user_groups__id=group_id
).distinct().values_list('id', flat=True))
granted_keys = Node.objects.filter(
granted_by_permissions__user_groups__id=group_id
granted_by_permissions__id__in=asset_perms_id
).values_list('key', flat=True)
asset_granted_keys = Node.objects.filter(
assets__granted_by_permissions__user_groups__id=group_id
assets__granted_by_permissions__id__in=asset_perms_id
).values_list('key', flat=True)
if node_key is None:

View File

@@ -3,6 +3,7 @@
from perms.api.asset.user_permission.mixin import UserNodeGrantStatusDispatchMixin
from rest_framework.generics import ListAPIView
from rest_framework.response import Response
from rest_framework.request import Request
from django.conf import settings
from assets.api.mixin import SerializeToTreeNodeMixin
@@ -30,6 +31,8 @@ class UserDirectGrantedAssetsApi(ListAPIView):
search_fields = ['hostname', 'ip', 'comment']
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
user = self.user
assets = get_user_direct_granted_assets(user)\
.prefetch_related('platform')\
@@ -44,6 +47,8 @@ class UserFavoriteGrantedAssetsApi(ListAPIView):
search_fields = ['hostname', 'ip', 'comment']
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
user = self.user
assets = FavoriteAsset.get_user_favorite_assets(user)\
.prefetch_related('platform')\
@@ -55,8 +60,12 @@ class AssetsAsTreeMixin(SerializeToTreeNodeMixin):
"""
Serialize assets into a tree structure and return them
"""
def list(self, request, *args, **kwargs):
def list(self, request: Request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
if request.query_params.get('search'):
# Imprecise search conditions can return a large amount of meaningless data.
# Limit the maximum number of returned records here
queryset = queryset[:999]
data = self.serialize_assets(queryset, None)
return Response(data=data)
@@ -96,6 +105,8 @@ class UserAllGrantedAssetsApi(ForAdminMixin, ListAPIView):
search_fields = ['hostname', 'ip', 'comment']
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
queryset = get_user_granted_all_assets(self.user)
queryset = queryset.prefetch_related('platform')
return queryset.only(*self.only_fields)
@@ -118,6 +129,8 @@ class UserGrantedNodeAssetsApi(UserNodeGrantStatusDispatchMixin, ListAPIView):
pagination_node: Node
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
node_id = self.kwargs.get("node_id")
node = Node.objects.get(id=node_id)
self.pagination_node = node

View File

@@ -139,11 +139,13 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
return Response(data=data)
class UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi(ForAdminMixin, UserNodeGrantStatusDispatchMixin,
SerializeToTreeNodeMixin, ListAPIView):
class GrantedNodeChildrenWithAssetsAsTreeApiMixin(UserNodeGrantStatusDispatchMixin,
SerializeToTreeNodeMixin,
ListAPIView):
"""
Permission tree with assets
"""
user: None
def get_data_on_node_direct_granted(self, key):
nodes = Node.objects.filter(parent_key=key)
@@ -203,5 +205,9 @@ class UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi(ForAdminMixin, UserNode
return Response(data=[*tree_nodes, *tree_assets])
class MyGrantedNodeChildrenWithAssetsAsTreeApi(ForUserMixin, UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi):
class UserGrantedNodeChildrenWithAssetsAsTreeApi(ForAdminMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin):
pass
class MyGrantedNodeChildrenWithAssetsAsTreeApi(ForUserMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin):
pass

View File

@@ -1,13 +1,106 @@
from django.db.models import F
from orgs.mixins.api import OrgBulkModelViewSet
from orgs.mixins.api import OrgRelationMixin
from django.db.models import Q
from common.permissions import IsOrgAdmin
from common.utils import get_object_or_none
from orgs.mixins.api import OrgBulkModelViewSet
from assets.models import SystemUser
from users.models import User, UserGroup
__all__ = [
'RelationViewSet'
'RelationViewSet', 'BasePermissionViewSet'
]
class BasePermissionViewSet(OrgBulkModelViewSet):
custom_filter_fields = [
'user_id', 'username', 'system_user_id', 'system_user',
'user_group_id', 'user_group'
]
permission_classes = (IsOrgAdmin,)
def filter_valid(self, queryset):
valid_query = self.request.query_params.get('is_valid', None)
if valid_query is None:
return queryset
invalid = valid_query in ['0', 'N', 'false', 'False']
if invalid:
queryset = queryset.invalid()
else:
queryset = queryset.valid()
return queryset
def is_query_all(self):
query_all = self.request.query_params.get('all', '1') == '1'
return query_all
def filter_user(self, queryset):
user_id = self.request.query_params.get('user_id')
username = self.request.query_params.get('username')
if user_id:
user = get_object_or_none(User, pk=user_id)
elif username:
user = get_object_or_none(User, username=username)
else:
return queryset
if not user:
return queryset.none()
if not self.is_query_all():
queryset = queryset.filter(users=user)
return queryset
groups = user.groups.all()
queryset = queryset.filter(
Q(users=user) | Q(user_groups__in=groups)
).distinct()
return queryset
def filter_keyword(self, queryset):
keyword = self.request.query_params.get('search')
if not keyword:
return queryset
queryset = queryset.filter(name__icontains=keyword)
return queryset
def filter_system_user(self, queryset):
system_user_id = self.request.query_params.get('system_user_id')
system_user_name = self.request.query_params.get('system_user')
if system_user_id:
system_user = get_object_or_none(SystemUser, pk=system_user_id)
elif system_user_name:
system_user = get_object_or_none(SystemUser, name=system_user_name)
else:
return queryset
if not system_user:
return queryset.none()
queryset = queryset.filter(system_users=system_user)
return queryset
def filter_user_group(self, queryset):
user_group_id = self.request.query_params.get('user_group_id')
user_group_name = self.request.query_params.get('user_group')
if user_group_id:
group = get_object_or_none(UserGroup, pk=user_group_id)
elif user_group_name:
group = get_object_or_none(UserGroup, name=user_group_name)
else:
return queryset
if not group:
return queryset.none()
queryset = queryset.filter(user_groups=group)
return queryset
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
queryset = self.filter_valid(queryset)
queryset = self.filter_user(queryset)
queryset = self.filter_system_user(queryset)
queryset = self.filter_user_group(queryset)
queryset = self.filter_keyword(queryset)
queryset = queryset.distinct()
return queryset
class RelationViewSet(OrgRelationMixin, OrgBulkModelViewSet):
def get_queryset(self):
queryset = super().get_queryset()

View File

@@ -62,6 +62,8 @@ class Action:
@classmethod
def value_to_choices(cls, value):
if isinstance(value, list):
return value
value = int(value)
choices = [cls.NAME_MAP[i] for i, j in cls.DB_CHOICES if value & i == i]
return choices

View File

@@ -16,3 +16,4 @@ class SystemUserSerializer(serializers.ModelSerializer):
'auto_push', 'cmd_filters', 'sudo', 'shell', 'comment',
'sftp_root', 'date_created', 'created_by'
]
ref_name = 'PermedSystemUserSerializer'

View File

@@ -6,16 +6,53 @@ from django.dispatch import receiver
from perms.tasks import create_rebuild_user_tree_task, \
create_rebuild_user_tree_task_by_related_nodes_or_assets
from users.models import User, UserGroup
from assets.models import Asset
from assets.models import Asset, SystemUser
from applications.models import Application, Category
from common.utils import get_logger
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR
from .models import AssetPermission, RemoteAppPermission
from .models import AssetPermission, RemoteAppPermission, ApplicationPermission
logger = get_logger(__file__)
def handle_rebuild_user_tree(instance, action, reverse, pk_set, **kwargs):
if action.startswith('post'):
if reverse:
create_rebuild_user_tree_task(pk_set)
else:
create_rebuild_user_tree_task([instance.id])
def handle_bind_groups_systemuser(instance, action, reverse, pk_set, **kwargs):
"""
When users are added to a UserGroup, those users must be associated with the dynamic system users linked to that UserGroup
"""
user: User
if action != POST_ADD:
return
if not reverse:
# One user was added to multiple user groups
users_id = [instance.id]
system_users = SystemUser.objects.filter(groups__id__in=pk_set).distinct()
else:
# Multiple users were added to one user group
users_id = pk_set
system_users = SystemUser.objects.filter(groups__id=instance.pk).distinct()
for system_user in system_users:
system_user.users.add(*users_id)
@receiver(m2m_changed, sender=User.groups.through)
def on_user_groups_change(**kwargs):
handle_rebuild_user_tree(**kwargs)
handle_bind_groups_systemuser(**kwargs)
@receiver([pre_save], sender=AssetPermission)
def on_asset_perm_deactive(instance: AssetPermission, **kwargs):
try:
@@ -208,3 +245,93 @@ def on_node_asset_change(action, instance, reverse, pk_set, **kwargs):
node_pk_set = pk_set
create_rebuild_user_tree_task_by_related_nodes_or_assets.delay(node_pk_set, asset_pk_set)
@receiver(m2m_changed, sender=ApplicationPermission.system_users.through)
def on_application_permission_system_users_changed(sender, instance: ApplicationPermission, action, reverse, pk_set, **kwargs):
if instance.category != Category.remote_app:
return
if reverse:
raise M2MReverseNotAllowed
if action != POST_ADD:
return
system_users = SystemUser.objects.filter(pk__in=pk_set)
logger.debug("Application permission system_users change signal received")
attrs = instance.applications.all().values_list('attrs', flat=True)
assets_id = [attr['asset'] for attr in attrs if attr.get('asset')]
if not assets_id:
return
for system_user in system_users:
system_user.assets.add(*assets_id)
if system_user.username_same_with_user:
users_id = instance.users.all().values_list('id', flat=True)
groups_id = instance.user_groups.all().values_list('id', flat=True)
system_user.groups.add(*groups_id)
system_user.users.add(*users_id)
@receiver(m2m_changed, sender=ApplicationPermission.users.through)
def on_application_permission_users_changed(sender, instance, action, reverse, pk_set, **kwargs):
if instance.category != Category.remote_app:
return
if reverse:
raise M2MReverseNotAllowed
if action != POST_ADD:
return
logger.debug("Application permission users change signal received")
users_id = User.objects.filter(pk__in=pk_set).values_list('id', flat=True)
system_users = instance.system_users.all()
for system_user in system_users:
if system_user.username_same_with_user:
system_user.users.add(*users_id)
@receiver(m2m_changed, sender=ApplicationPermission.user_groups.through)
def on_application_permission_user_groups_changed(sender, instance, action, reverse, pk_set, **kwargs):
if instance.category != Category.remote_app:
return
if reverse:
raise M2MReverseNotAllowed
if action != POST_ADD:
return
logger.debug("Application permission user groups change signal received")
groups_id = UserGroup.objects.filter(pk__in=pk_set).values_list('id', flat=True)
system_users = instance.system_users.all()
for system_user in system_users:
if system_user.username_same_with_user:
system_user.groups.add(*groups_id)
@receiver(m2m_changed, sender=ApplicationPermission.applications.through)
def on_application_permission_applications_changed(sender, instance, action, reverse, pk_set, **kwargs):
if instance.category != Category.remote_app:
return
if reverse:
raise M2MReverseNotAllowed
if action != POST_ADD:
return
attrs = Application.objects.filter(id__in=pk_set).values_list('attrs', flat=True)
assets_id = [attr['asset'] for attr in attrs if attr.get('asset')]
if not assets_id:
return
system_users = instance.system_users.all()
for system_user in system_users:
system_user.assets.add(*assets_id)

View File

@@ -5,10 +5,12 @@ from datetime import timedelta
from django.db import transaction
from django.db.models import Q
from django.db.transaction import atomic
from django.conf import settings
from celery import shared_task
from common.utils import get_logger
from common.utils.timezone import now, dt_formater, dt_parser
from users.models import User
from ops.celery.decorator import register_as_period_task
from assets.models import Node
from perms.models import RebuildUserTreeTask, AssetPermission
from perms.utils.asset.user_permission import rebuild_user_mapping_nodes_if_need_with_lock, lock
@@ -33,7 +35,8 @@ def dispatch_mapping_node_tasks():
rebuild_user_mapping_nodes_celery_task.delay(id)
@shared_task(queue='check_asset_perm_expired')
@register_as_period_task(interval=settings.PERM_EXPIRED_CHECK_PERIODIC)
@shared_task(queue='celery_check_asset_perm_expired')
@atomic()
def check_asset_permission_expired():
"""

View File

@@ -21,7 +21,7 @@ user_permission_urlpatterns = [
# ---------------------------------------------------------
# Returned in serializer format
path('<uuid:pk>/assets/', api.UserAllGrantedAssetsApi.as_view(), name='user-assets'),
path('assets/', api.MyAllAssetsAsTreeApi.as_view(), name='my-assets'),
path('assets/', api.MyAllGrantedAssetsApi.as_view(), name='my-assets'),
# Returned in tree node data format
path('<uuid:pk>/assets/tree/', api.UserDirectGrantedAssetsAsTreeForAdminApi.as_view(), name='user-assets-as-tree'),
@@ -56,7 +56,7 @@ user_permission_urlpatterns = [
path('nodes-with-assets/tree/', api.MyGrantedNodesWithAssetsAsTreeApi.as_view(), name='my-nodes-with-assets-as-tree'),
# Mainly used by the luna page: node tree with assets
path('<uuid:pk>/nodes/children-with-assets/tree/', api.UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi.as_view(), name='user-nodes-children-with-assets-as-tree'),
path('<uuid:pk>/nodes/children-with-assets/tree/', api.UserGrantedNodeChildrenWithAssetsAsTreeApi.as_view(), name='user-nodes-children-with-assets-as-tree'),
path('nodes/children-with-assets/tree/', api.MyGrantedNodeChildrenWithAssetsAsTreeApi.as_view(), name='my-nodes-children-with-assets-as-tree'),
# Query all assets of a node in the permission tree

View File

@@ -34,27 +34,6 @@ TMP_ASSET_GRANTED_FIELD = '_asset_granted'
TMP_GRANTED_ASSETS_AMOUNT_FIELD = '_granted_assets_amount'
# Usage:
# Asset.objects.filter(get_user_resources_q_granted_by_permissions(user))
def get_user_resources_q_granted_by_permissions(user: User):
"""
Build the query rule from asset permissions related to the user or to the user's groups,
assuming the related_name on the AssetPermission model is granted_by_permissions
:param user:
:return:
"""
_now = now()
return reduce(and_, (
Q(granted_by_permissions__date_start__lt=_now),
Q(granted_by_permissions__date_expired__gt=_now),
Q(granted_by_permissions__is_active=True),
(
Q(granted_by_permissions__users=user) |
Q(granted_by_permissions__user_groups__users=user)
)
))
# Usage:
# `Node.objects.annotate(**node_annotate_mapping_node)`
node_annotate_mapping_node = {
@@ -215,7 +194,7 @@ def compute_tmp_mapping_node_from_perm(user: User, asset_perms_id=None):
return [*leaf_nodes, *ancestors]
def create_mapping_nodes(user, nodes, clear=True):
def create_mapping_nodes(user, nodes):
to_create = []
for node in nodes:
_granted = getattr(node, TMP_GRANTED_FIELD, False)
@@ -231,8 +210,6 @@ def create_mapping_nodes(user, nodes, clear=True):
assets_amount=_granted_assets_amount,
))
if clear:
UserGrantedMappingNode.objects.filter(user=user).delete()
UserGrantedMappingNode.objects.bulk_create(to_create)
@@ -254,6 +231,9 @@ def set_node_granted_assets_amount(user, node, asset_perms_id=None):
@tmp_to_root_org()
def rebuild_user_mapping_nodes(user):
logger.info(f'>>> {dt_formater(now())} start rebuild {user} mapping nodes')
# Delete the old permission tree first
UserGrantedMappingNode.objects.filter(user=user).delete()
asset_perms_id = get_user_all_assetpermissions_id(user)
if not asset_perms_id:
# Return directly if there are no permissions
@@ -384,7 +364,8 @@ def get_node_all_granted_assets(user: User, key):
if only_asset_granted_nodes_qs:
only_asset_granted_nodes_q = reduce(or_, only_asset_granted_nodes_qs)
only_asset_granted_nodes_q &= get_user_resources_q_granted_by_permissions(user)
asset_perms_id = get_user_all_assetpermissions_id(user)
only_asset_granted_nodes_q &= Q(granted_by_permissions__id__in=list(asset_perms_id))
q.append(only_asset_granted_nodes_q)
if q:
@@ -484,6 +465,9 @@ def get_user_all_assetpermissions_id(user: User):
asset_perms_id = AssetPermission.objects.valid().filter(
Q(users=user) | Q(user_groups__users=user)
).distinct().values_list('id', flat=True)
# !!! Important: must be converted to a list to prevent Django from generating a nested subquery
asset_perms_id = list(asset_perms_id)
return asset_perms_id

View File

@@ -145,9 +145,22 @@ class LDAPServerUtil(object):
paged_cookie=paged_cookie
)
@staticmethod
def distinct_user_entries(user_entries):
distinct_user_entries = list()
distinct_user_entries_dn = set()
for user_entry in user_entries:
if user_entry.entry_dn in distinct_user_entries_dn:
continue
distinct_user_entries_dn.add(user_entry.entry_dn)
distinct_user_entries.append(user_entry)
return distinct_user_entries
@timeit
def search_user_entries(self):
def search_user_entries(self, search_users=None, search_value=None):
logger.info("Search user entries")
self.search_users = search_users
self.search_value = search_value
user_entries = list()
search_ous = str(self.config.search_ou).split('|')
for search_ou in search_ous:
@@ -157,6 +170,7 @@ class LDAPServerUtil(object):
while self.paged_cookie():
self.search_user_entries_ou(search_ou, self.paged_cookie())
user_entries.extend(self.connection.entries)
user_entries = self.distinct_user_entries(user_entries)
return user_entries
def user_entry_to_dict(self, entry):
@@ -172,7 +186,6 @@ class LDAPServerUtil(object):
user[attr] = value
return user
@timeit
def user_entries_to_dict(self, user_entries):
users = []
for user_entry in user_entries:
@@ -180,12 +193,21 @@ class LDAPServerUtil(object):
users.append(user)
return users
def search_for_user_dn(self, username):
user_entries = self.search_user_entries(search_users=[username])
if len(user_entries) == 1:
user_entry = user_entries[0]
user_dn = user_entry.entry_dn
else:
user_dn = None
return user_dn
@timeit
def search(self, search_users=None, search_value=None):
logger.info("Search ldap users")
self.search_users = search_users
self.search_value = search_value
user_entries = self.search_user_entries()
user_entries = self.search_user_entries(
search_users=search_users, search_value=search_value
)
users = self.user_entries_to_dict(user_entries)
return users
@@ -333,7 +355,7 @@ class LDAPImportUtil(object):
def update_or_create(self, user):
user['email'] = self.get_user_email(user)
if user['username'] not in ['admin']:
user['source'] = User.SOURCE_LDAP
user['source'] = User.Source.ldap.value
obj, created = User.objects.update_or_create(
username=user['username'], defaults=user
)

View File

@@ -1,12 +0,0 @@
{{image}}{{hidden_field}}{{text_field}}
<script>
function refresh_captcha() {
$.getJSON("{% url "captcha-refresh" %}",
function (result) {
$('.captcha').attr('src', result['image_url']);
$('#id_captcha_0').val(result['key'])
})
}
$('.captcha').click(refresh_captcha)
</script>

View File

@@ -1 +0,0 @@
<input id="{{id}}_0" name="{{name}}_0" type="hidden" value="{{key}}" />

View File

@@ -1,4 +0,0 @@
{% load i18n %}
{% spaceless %}
{% if audio %}<a title="{% trans "Play CAPTCHA as audio file" %}" href="{{audio}}">{% endif %}<img src="{{image}}" alt="captcha" class="captcha" />{% if audio %}</a>{% endif %}
{% endspaceless %}

View File

@@ -1,7 +0,0 @@
{% load i18n %}
<div class="row">
<div class="col-sm-6">
<input autocomplete="off" id="{{id}}_1" class="form-control" name="{{name}}_1" placeholder="{% trans 'Captcha' %}" type="text" />
</div>
</div>
</br>

View File

@@ -5,3 +5,4 @@ from .session import *
from .command import *
from .task import *
from .storage import *
from .component import *

View File

@@ -60,7 +60,7 @@ class CommandQueryMixin:
queryset = multi_command_storage.filter(
date_from=date_from, date_to=date_to,
user=q.get("user"), asset=q.get("asset"), system_user=q.get("system_user"),
input=q.get("input"), session=q.get("session_id"),
input=q.get("input"), session=q.get("session_id", q.get('session')),
risk_level=self.get_query_risk_level(), org_id=self.get_org_id(),
)
return queryset
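The two-level q.get(...) keeps older clients working: session_id is the preferred filter parameter, but a bare session value is still honored. A quick illustration of the fallback:

# Either parameter name resolves to the same filter value.
new_style = {'session_id': 'some-session-id'}
legacy = {'session': 'some-session-id'}
assert new_style.get('session_id', new_style.get('session')) == 'some-session-id'
assert legacy.get('session_id', legacy.get('session')) == 'some-session-id'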

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
#
import logging
from rest_framework import generics, status
from rest_framework.views import Response
from .. import serializers
from ..utils import ComponentsMetricsUtil
from common.permissions import IsAppUser, IsSuperUser
logger = logging.getLogger(__file__)
__all__ = [
'ComponentsStateAPIView', 'ComponentsMetricsAPIView',
]
class ComponentsStateAPIView(generics.CreateAPIView):
""" koko, guacamole, omnidb 上报状态 """
permission_classes = (IsAppUser,)
serializer_class = serializers.ComponentsStateSerializer
class ComponentsMetricsAPIView(generics.GenericAPIView):
""" 返回汇总组件指标数据 """
permission_classes = (IsSuperUser,)
def get(self, request, *args, **kwargs):
tp = request.query_params.get('type')
util = ComponentsMetricsUtil()
metrics = util.get_metrics(tp)
return Response(metrics, status=status.HTTP_200_OK)
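A minimal sketch of exercising this view with DRF's test utilities rather than a routed URL (the import path follows the `from .component import *` line added above, and some_superuser is a placeholder for any user that satisfies IsSuperUser):

from rest_framework.test import APIRequestFactory, force_authenticate
from terminal.api.component import ComponentsMetricsAPIView   # import path assumed

factory = APIRequestFactory()
# `type` presumably takes one of the terminal types (koko/guacamole/omnidb);
# omit it to aggregate metrics for every component.
request = factory.get('/components/metrics/', {'type': 'koko'})   # path is arbitrary here
force_authenticate(request, user=some_superuser)                  # placeholder user
response = ComponentsMetricsAPIView.as_view()(request)
print(response.status_code, response.data)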

View File

@@ -27,7 +27,7 @@ class TerminalViewSet(JMSBulkModelViewSet):
queryset = Terminal.objects.filter(is_deleted=False)
serializer_class = serializers.TerminalSerializer
permission_classes = (IsSuperUser,)
filter_fields = ['name', 'remote_addr']
filter_fields = ['name', 'remote_addr', 'type']
def create(self, request, *args, **kwargs):
if isinstance(request.data, list):
@@ -60,6 +60,15 @@ class TerminalViewSet(JMSBulkModelViewSet):
logger.error("Register terminal error: {}".format(data))
return Response(data, status=400)
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
status = self.request.query_params.get('status')
if not status:
return queryset
filtered_queryset_id = [str(q.id) for q in queryset if q.status == status]
queryset = queryset.filter(id__in=filtered_queryset_id)
return queryset
def get_permissions(self):
if self.action == "create":
self.permission_classes = (AllowAny,)
@@ -104,15 +113,11 @@ class StatusViewSet(viewsets.ModelViewSet):
task_serializer_class = serializers.TaskSerializer
def create(self, request, *args, **kwargs):
self.handle_status(request)
self.handle_sessions()
tasks = self.request.user.terminal.task_set.filter(is_finished=False)
serializer = self.task_serializer_class(tasks, many=True)
return Response(serializer.data, status=201)
def handle_status(self, request):
request.user.terminal.is_alive = True
def handle_sessions(self):
sessions_id = self.request.data.get('sessions', [])
# sessions reported by guacamole are strings
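handle_status only refreshes the reporting terminal's liveness flag; as the Terminal model further down in this diff shows, is_alive is backed by a cache key with a 60-second timeout rather than a database column, so a terminal that stops reporting simply ages out. A minimal sketch of that heartbeat pattern using Django's cache API:

from django.core.cache import cache

STATUS_KEY_PREFIX = 'terminal_status_'

def mark_alive(terminal_id, ttl=60):
    # Each status report refreshes the key; if no report arrives within
    # `ttl` seconds the key expires and the terminal reads as offline.
    cache.set(STATUS_KEY_PREFIX + str(terminal_id), True, ttl)

def is_alive(terminal_id):
    return bool(cache.get(STATUS_KEY_PREFIX + str(terminal_id)))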

View File

@@ -14,7 +14,7 @@ class SessionCommandSerializer(serializers.Serializer):
system_user = serializers.CharField(max_length=64, label=_("System user"))
input = serializers.CharField(max_length=128, label=_("Command"))
output = serializers.CharField(max_length=1024, allow_blank=True, label=_("Output"))
session = serializers.CharField(max_length=36, label=_("Session"))
session = serializers.CharField(max_length=36, label=_("Session ID"))
risk_level = serializers.ChoiceField(required=False, label=_("Risk level"), choices=AbstractSessionCommand.RISK_LEVEL_CHOICES)
risk_level_display = serializers.SerializerMethodField(label=_('Risk level for display'))
org_id = serializers.CharField(max_length=36, required=False, default='', allow_null=True, allow_blank=True)

View File

@@ -108,3 +108,27 @@ COMMAND_STORAGE_TYPE_CHOICES_EXTENDS = [
COMMAND_STORAGE_TYPE_CHOICES = COMMAND_STORAGE_TYPE_CHOICES_DEFAULT + \
COMMAND_STORAGE_TYPE_CHOICES_EXTENDS
from django.db.models import TextChoices
from django.utils.translation import ugettext_lazy as _
class ComponentStatusChoices(TextChoices):
critical = 'critical', _('Critical')
high = 'high', _('High')
normal = 'normal', _('Normal')
@classmethod
def status(cls):
return set(dict(cls.choices).keys())
class TerminalTypeChoices(TextChoices):
koko = 'koko', 'KoKo'
guacamole = 'guacamole', 'Guacamole'
omnidb = 'omnidb', 'OmniDB'
@classmethod
def types(cls):
return set(dict(cls.choices).keys())
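For reference, a short sketch of how these Django TextChoices helpers behave (the import path is assumed; within the project they live in the terminal app's const module):

from terminal.const import TerminalTypeChoices, ComponentStatusChoices   # path assumed

TerminalTypeChoices.koko.value     # 'koko'
TerminalTypeChoices.koko.label     # 'KoKo'
TerminalTypeChoices.types()        # {'koko', 'guacamole', 'omnidb'}
ComponentStatusChoices.status()    # {'critical', 'high', 'normal'}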

View File

@@ -0,0 +1,41 @@
# Generated by Django 3.1 on 2020-12-15 04:52
from django.db import migrations, models
TERMINAL_TYPE_KOKO = 'koko'
TERMINAL_TYPE_GUACAMOLE = 'guacamole'
TERMINAL_TYPE_OMNIDB = 'omnidb'
def migrate_terminal_type(apps, schema_editor):
terminal_model = apps.get_model("terminal", "Terminal")
db_alias = schema_editor.connection.alias
terminals = terminal_model.objects.using(db_alias).all()
for terminal in terminals:
name = terminal.name.lower()
if 'koko' in name:
_type = TERMINAL_TYPE_KOKO
elif 'gua' in name:
_type = TERMINAL_TYPE_GUACAMOLE
elif 'omnidb' in name:
_type = TERMINAL_TYPE_OMNIDB
else:
_type = TERMINAL_TYPE_KOKO
terminal.type = _type
terminal_model.objects.bulk_update(terminals, ['type'])
class Migration(migrations.Migration):
dependencies = [
('terminal', '0029_auto_20201116_1757'),
]
operations = [
migrations.AddField(
model_name='terminal',
name='type',
field=models.CharField(choices=[('koko', 'KoKo'), ('guacamole', 'Guacamole'), ('omnidb', 'OmniDB')], default='koko', max_length=64, verbose_name='type'),
),
migrations.RunPython(migrate_terminal_type)
]
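The data migration guesses each existing terminal's type from its name and falls back to koko when nothing matches. A standalone sketch of that inference, for a quick sanity check:

def guess_terminal_type(name: str) -> str:
    # Mirrors the mapping used in migrate_terminal_type above.
    name = name.lower()
    if 'koko' in name:
        return 'koko'
    if 'gua' in name:          # matches 'guacamole' and shortened 'gua' names
        return 'guacamole'
    if 'omnidb' in name:
        return 'omnidb'
    return 'koko'              # default to koko, as the migration does

assert guess_terminal_type('koko-01') == 'koko'
assert guess_terminal_type('Guacamole-host-2') == 'guacamole'
assert guess_terminal_type('legacy-terminal') == 'koko'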

View File

@@ -1,486 +0,0 @@
from __future__ import unicode_literals
import os
import uuid
import jms_storage
from django.db import models
from django.db.models.signals import post_save
from django.utils.translation import ugettext_lazy as _
from django.utils import timezone
from django.conf import settings
from django.core.files.storage import default_storage
from django.core.cache import cache
from assets.models import Asset
from users.models import User
from orgs.mixins.models import OrgModelMixin
from common.mixins import CommonModelMixin
from common.fields.model import EncryptJsonDictTextField
from common.db.models import ChoiceSet
from .backends import get_multi_command_storage
from .backends.command.models import AbstractSessionCommand
from . import const
class Terminal(models.Model):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=128, verbose_name=_('Name'))
remote_addr = models.CharField(max_length=128, blank=True, verbose_name=_('Remote Address'))
ssh_port = models.IntegerField(verbose_name=_('SSH Port'), default=2222)
http_port = models.IntegerField(verbose_name=_('HTTP Port'), default=5000)
command_storage = models.CharField(max_length=128, verbose_name=_("Command storage"), default='default')
replay_storage = models.CharField(max_length=128, verbose_name=_("Replay storage"), default='default')
user = models.OneToOneField(User, related_name='terminal', verbose_name='Application User', null=True, on_delete=models.CASCADE)
is_accepted = models.BooleanField(default=False, verbose_name='Is Accepted')
is_deleted = models.BooleanField(default=False)
date_created = models.DateTimeField(auto_now_add=True)
comment = models.TextField(blank=True, verbose_name=_('Comment'))
STATUS_KEY_PREFIX = 'terminal_status_'
@property
def is_alive(self):
key = self.STATUS_KEY_PREFIX + str(self.id)
return bool(cache.get(key))
@is_alive.setter
def is_alive(self, value):
key = self.STATUS_KEY_PREFIX + str(self.id)
cache.set(key, value, 60)
@property
def is_active(self):
if self.user and self.user.is_active:
return True
return False
@is_active.setter
def is_active(self, active):
if self.user:
self.user.is_active = active
self.user.save()
def get_command_storage(self):
storage = CommandStorage.objects.filter(name=self.command_storage).first()
return storage
def get_command_storage_config(self):
s = self.get_command_storage()
if s:
config = s.config
else:
config = settings.DEFAULT_TERMINAL_COMMAND_STORAGE
return config
def get_command_storage_setting(self):
config = self.get_command_storage_config()
return {"TERMINAL_COMMAND_STORAGE": config}
def get_replay_storage(self):
storage = ReplayStorage.objects.filter(name=self.replay_storage).first()
return storage
def get_replay_storage_config(self):
s = self.get_replay_storage()
if s:
config = s.config
else:
config = settings.DEFAULT_TERMINAL_REPLAY_STORAGE
return config
def get_replay_storage_setting(self):
config = self.get_replay_storage_config()
return {"TERMINAL_REPLAY_STORAGE": config}
@staticmethod
def get_login_title_setting():
login_title = None
if settings.XPACK_ENABLED:
from xpack.plugins.interface.models import Interface
login_title = Interface.get_login_title()
return {'TERMINAL_HEADER_TITLE': login_title}
@property
def config(self):
configs = {}
for k in dir(settings):
if not k.startswith('TERMINAL'):
continue
configs[k] = getattr(settings, k)
configs.update(self.get_command_storage_setting())
configs.update(self.get_replay_storage_setting())
configs.update(self.get_login_title_setting())
configs.update({
'SECURITY_MAX_IDLE_TIME': settings.SECURITY_MAX_IDLE_TIME
})
return configs
@property
def service_account(self):
return self.user
def create_app_user(self):
random = uuid.uuid4().hex[:6]
user, access_key = User.create_app_user(
name="{}-{}".format(self.name, random), comment=self.comment
)
self.user = user
self.save()
return user, access_key
def delete(self, using=None, keep_parents=False):
if self.user:
self.user.delete()
self.user = None
self.is_deleted = True
self.save()
return
def __str__(self):
status = "Active"
if not self.is_accepted:
status = "NotAccept"
elif self.is_deleted:
status = "Deleted"
elif not self.is_active:
status = "Disable"
return '%s: %s' % (self.name, status)
class Meta:
ordering = ('is_accepted',)
db_table = "terminal"
class Status(models.Model):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
session_online = models.IntegerField(verbose_name=_("Session Online"), default=0)
cpu_used = models.FloatField(verbose_name=_("CPU Usage"))
memory_used = models.FloatField(verbose_name=_("Memory Used"))
connections = models.IntegerField(verbose_name=_("Connections"))
threads = models.IntegerField(verbose_name=_("Threads"))
boot_time = models.FloatField(verbose_name=_("Boot Time"))
terminal = models.ForeignKey(Terminal, null=True, on_delete=models.CASCADE)
date_created = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'terminal_status'
get_latest_by = 'date_created'
def __str__(self):
return self.date_created.strftime("%Y-%m-%d %H:%M:%S")
class Session(OrgModelMixin):
class LOGIN_FROM(ChoiceSet):
ST = 'ST', 'SSH Terminal'
WT = 'WT', 'Web Terminal'
class PROTOCOL(ChoiceSet):
SSH = 'ssh', 'ssh'
RDP = 'rdp', 'rdp'
VNC = 'vnc', 'vnc'
TELNET = 'telnet', 'telnet'
MYSQL = 'mysql', 'mysql'
ORACLE = 'oracle', 'oracle'
MARIADB = 'mariadb', 'mariadb'
POSTGRESQL = 'postgresql', 'postgresql'
K8S = 'k8s', 'kubernetes'
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
user = models.CharField(max_length=128, verbose_name=_("User"), db_index=True)
user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
asset = models.CharField(max_length=128, verbose_name=_("Asset"), db_index=True)
asset_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
system_user = models.CharField(max_length=128, verbose_name=_("System user"), db_index=True)
system_user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
login_from = models.CharField(max_length=2, choices=LOGIN_FROM.choices, default="ST", verbose_name=_("Login from"))
remote_addr = models.CharField(max_length=128, verbose_name=_("Remote addr"), blank=True, null=True)
is_success = models.BooleanField(default=True, db_index=True)
is_finished = models.BooleanField(default=False, db_index=True)
has_replay = models.BooleanField(default=False, verbose_name=_("Replay"))
has_command = models.BooleanField(default=False, verbose_name=_("Command"))
terminal = models.ForeignKey(Terminal, null=True, on_delete=models.DO_NOTHING, db_constraint=False)
protocol = models.CharField(choices=PROTOCOL.choices, default='ssh', max_length=16, db_index=True)
date_start = models.DateTimeField(verbose_name=_("Date start"), db_index=True, default=timezone.now)
date_end = models.DateTimeField(verbose_name=_("Date end"), null=True)
upload_to = 'replay'
ACTIVE_CACHE_KEY_PREFIX = 'SESSION_ACTIVE_{}'
_DATE_START_FIRST_HAS_REPLAY_RDP_SESSION = None
def get_rel_replay_path(self, version=2):
"""
Get the relative file path of the session replay log.
:param version: the legacy suffix was .gz; for consistency the new version uses .replay.gz
:return:
"""
suffix = '.replay.gz'
if version == 1:
suffix = '.gz'
date = self.date_start.strftime('%Y-%m-%d')
return os.path.join(date, str(self.id) + suffix)
def get_local_path(self, version=2):
rel_path = self.get_rel_replay_path(version=version)
if version == 2:
local_path = os.path.join(self.upload_to, rel_path)
else:
local_path = rel_path
return local_path
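For example, under the layout above (the date and session id below are made up):

import os

upload_to = 'replay'
date = '2020-12-15'
session_id = 'e3b0c442-98fc-4c14-9afb-f4c8996fb924'

rel_v2 = os.path.join(date, session_id + '.replay.gz')   # new-style suffix
rel_v1 = os.path.join(date, session_id + '.gz')          # legacy suffix
local_v2 = os.path.join(upload_to, rel_v2)               # 'replay/2020-12-15/<id>.replay.gz'
local_v1 = rel_v1                                        # legacy replays are not nested under 'replay/'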
@property
def asset_obj(self):
return Asset.objects.get(id=self.asset_id)
@property
def _date_start_first_has_replay_rdp_session(self):
if self.__class__._DATE_START_FIRST_HAS_REPLAY_RDP_SESSION is None:
instance = self.__class__.objects.filter(
protocol='rdp', has_replay=True
).order_by('date_start').first()
if not instance:
date_start = timezone.now() - timezone.timedelta(days=365)
else:
date_start = instance.date_start
self.__class__._DATE_START_FIRST_HAS_REPLAY_RDP_SESSION = date_start
return self.__class__._DATE_START_FIRST_HAS_REPLAY_RDP_SESSION
def can_replay(self):
if self.has_replay:
return True
if self.date_start < self._date_start_first_has_replay_rdp_session:
return True
return False
@property
def can_join(self):
_PROTOCOL = self.PROTOCOL
if self.is_finished:
return False
if self.protocol in [_PROTOCOL.SSH, _PROTOCOL.TELNET, _PROTOCOL.K8S]:
return True
else:
return False
@property
def db_protocols(self):
_PROTOCOL = self.PROTOCOL
return [_PROTOCOL.MYSQL, _PROTOCOL.MARIADB, _PROTOCOL.ORACLE, _PROTOCOL.POSTGRESQL]
@property
def can_terminate(self):
_PROTOCOL = self.PROTOCOL
if self.is_finished:
return False
if self.protocol in self.db_protocols:
return False
else:
return True
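Taken together, for an unfinished session these properties gate the UI actions by protocol: ssh/telnet/k8s sessions can be joined and terminated, rdp/vnc can only be terminated, and the database protocols can be neither. Illustrative only (run inside a configured Django shell; import path assumed):

from terminal.models import Session   # import path assumed

s = Session(protocol='rdp')   # unsaved instance, just to inspect the properties
s.can_join                    # False: only ssh/telnet/k8s sessions can be joined
s.can_terminate               # True: rdp is not one of the database protocols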
def save_replay_to_storage(self, f):
local_path = self.get_local_path()
try:
name = default_storage.save(local_path, f)
except OSError as e:
return None, e
if settings.SERVER_REPLAY_STORAGE:
from .tasks import upload_session_replay_to_external_storage
upload_session_replay_to_external_storage.delay(str(self.id))
return name, None
@classmethod
def set_sessions_active(cls, sessions_id):
data = {cls.ACTIVE_CACHE_KEY_PREFIX.format(i): i for i in sessions_id}
cache.set_many(data, timeout=5*60)
@classmethod
def get_active_sessions(cls):
return cls.objects.filter(is_finished=False)
def is_active(self):
if self.protocol in ['ssh', 'telnet', 'rdp', 'mysql']:
key = self.ACTIVE_CACHE_KEY_PREFIX.format(self.id)
return bool(cache.get(key))
return True
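Session activity for interactive protocols is tracked through the cache in the same way: each component status report refreshes the keys via set_sessions_active, and a session whose key has expired is treated as inactive by is_active. A rough sketch of that round trip (the session id is made up):

from django.core.cache import cache

ACTIVE_CACHE_KEY_PREFIX = 'SESSION_ACTIVE_{}'
sessions_id = ['0b9d2f0a-1111-2222-3333-444455556666']

# Component heartbeat: mark the reported sessions active for five minutes.
cache.set_many(
    {ACTIVE_CACHE_KEY_PREFIX.format(i): i for i in sessions_id},
    timeout=5 * 60,
)

# Later check: an ssh/telnet/rdp/mysql session counts as active only while its key lives.
bool(cache.get(ACTIVE_CACHE_KEY_PREFIX.format(sessions_id[0])))   # True within five minutes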
@property
def command_amount(self):
command_store = get_multi_command_storage()
return command_store.count(session=str(self.id))
@property
def login_from_display(self):
return self.get_login_from_display()
@classmethod
def generate_fake(cls, count=100, is_finished=True):
import random
from orgs.models import Organization
from users.models import User
from assets.models import Asset, SystemUser
from orgs.utils import get_current_org
from common.utils.random import random_datetime, random_ip
org = get_current_org()
if not org or not org.is_real():
Organization.default().change_to()
i = 0
users = User.objects.all()[:100]
assets = Asset.objects.all()[:100]
system_users = SystemUser.objects.all()[:100]
while i < count:
user_random = random.choices(users, k=10)
assets_random = random.choices(assets, k=10)
system_users = random.choices(system_users, k=10)
ziped = zip(user_random, assets_random, system_users)
sessions = []
now = timezone.now()
month_ago = now - timezone.timedelta(days=30)
for user, asset, system_user in ziped:
ip = random_ip()
date_start = random_datetime(month_ago, now)
date_end = random_datetime(date_start, date_start+timezone.timedelta(hours=2))
data = dict(
user=str(user), user_id=user.id,
asset=str(asset), asset_id=asset.id,
system_user=str(system_user), system_user_id=system_user.id,
remote_addr=ip,
date_start=date_start,
date_end=date_end,
is_finished=is_finished,
)
sessions.append(Session(**data))
cls.objects.bulk_create(sessions)
i += 10
class Meta:
db_table = "terminal_session"
ordering = ["-date_start"]
def __str__(self):
return "{0.id} of {0.user} to {0.asset}".format(self)
class Task(models.Model):
NAME_CHOICES = (
("kill_session", "Kill Session"),
)
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
name = models.CharField(max_length=128, choices=NAME_CHOICES, verbose_name=_("Name"))
args = models.CharField(max_length=1024, verbose_name=_("Args"))
terminal = models.ForeignKey(Terminal, null=True, on_delete=models.SET_NULL)
is_finished = models.BooleanField(default=False)
date_created = models.DateTimeField(auto_now_add=True)
date_finished = models.DateTimeField(null=True)
class Meta:
db_table = "terminal_task"
class CommandManager(models.Manager):
def bulk_create(self, objs, **kwargs):
resp = super().bulk_create(objs, **kwargs)
for i in objs:
post_save.send(i.__class__, instance=i, created=True)
return resp
class Command(AbstractSessionCommand):
objects = CommandManager()
class Meta:
db_table = "terminal_command"
ordering = ('-timestamp',)
class CommandStorage(CommonModelMixin):
TYPE_CHOICES = const.COMMAND_STORAGE_TYPE_CHOICES
TYPE_DEFAULTS = dict(const.REPLAY_STORAGE_TYPE_CHOICES_DEFAULT).keys()
TYPE_SERVER = const.COMMAND_STORAGE_TYPE_SERVER
name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)
type = models.CharField(
max_length=16, choices=TYPE_CHOICES, verbose_name=_('Type'),
default=TYPE_SERVER
)
meta = EncryptJsonDictTextField(default={})
comment = models.TextField(
max_length=128, default='', blank=True, verbose_name=_('Comment')
)
def __str__(self):
return self.name
@property
def config(self):
config = self.meta
config.update({'TYPE': self.type})
return config
def in_defaults(self):
return self.type in self.TYPE_DEFAULTS
def is_valid(self):
if self.in_defaults():
return True
storage = jms_storage.get_log_storage(self.config)
return storage.ping()
def is_using(self):
return Terminal.objects.filter(command_storage=self.name).exists()
class ReplayStorage(CommonModelMixin):
TYPE_CHOICES = const.REPLAY_STORAGE_TYPE_CHOICES
TYPE_SERVER = const.REPLAY_STORAGE_TYPE_SERVER
TYPE_DEFAULTS = dict(const.REPLAY_STORAGE_TYPE_CHOICES_DEFAULT).keys()
name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)
type = models.CharField(
max_length=16, choices=TYPE_CHOICES, verbose_name=_('Type'),
default=TYPE_SERVER
)
meta = EncryptJsonDictTextField(default={})
comment = models.TextField(
max_length=128, default='', blank=True, verbose_name=_('Comment')
)
def __str__(self):
return self.name
def convert_type(self):
s3_type_list = [const.REPLAY_STORAGE_TYPE_CEPH]
tp = self.type
if tp in s3_type_list:
tp = const.REPLAY_STORAGE_TYPE_S3
return tp
def get_extra_config(self):
extra_config = {'TYPE': self.convert_type()}
if self.type == const.REPLAY_STORAGE_TYPE_SWIFT:
extra_config.update({'signer': 'S3SignerType'})
return extra_config
@property
def config(self):
config = self.meta
extra_config = self.get_extra_config()
config.update(extra_config)
return config
def in_defaults(self):
return self.type in self.TYPE_DEFAULTS
def is_valid(self):
if self.in_defaults():
return True
storage = jms_storage.get_object_storage(self.config)
target = 'tests.py'
src = os.path.join(settings.BASE_DIR, 'common', target)
return storage.is_valid(src, target)
def is_using(self):
return Terminal.objects.filter(replay_storage=self.name).exists()

View File

@@ -0,0 +1,6 @@
from .command import *
from .session import *
from .status import *
from .storage import *
from .task import *
from .terminal import *

View File

@@ -0,0 +1,21 @@
from __future__ import unicode_literals
from django.db import models
from django.db.models.signals import post_save
from ..backends.command.models import AbstractSessionCommand
class CommandManager(models.Manager):
def bulk_create(self, objs, **kwargs):
resp = super().bulk_create(objs, **kwargs)
for i in objs:
post_save.send(i.__class__, instance=i, created=True)
return resp
class Command(AbstractSessionCommand):
objects = CommandManager()
class Meta:
db_table = "terminal_command"
ordering = ('-timestamp',)
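Django's bulk_create normally skips post_save signals, so the custom manager re-sends them for each object; any receiver wired to Command therefore still fires for bulk inserts. A hypothetical receiver, purely for illustration (import path assumed):

from django.db.models.signals import post_save
from django.dispatch import receiver
from terminal.models import Command   # import path assumed

@receiver(post_save, sender=Command)
def on_command_saved(sender, instance, created, **kwargs):
    if created:
        # e.g. index the command or alert on high risk levels
        print('new command recorded:', instance)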

Some files were not shown because too many files have changed in this diff.