perf: playbook task db save if conn timeout

This commit is contained in:
ibuler 2025-07-10 18:32:22 +08:00 committed by Bryan
parent b564bbebb3
commit c7dcf1ba59
6 changed files with 50 additions and 43 deletions

View File

@@ -15,11 +15,13 @@ from common.decorators import bulk_create_decorator, bulk_update_decorator
 from settings.models import LeakPasswords

+# 已设置手动 finish
 @bulk_create_decorator(AccountRisk)
 def create_risk(data):
     return AccountRisk(**data)

+# 已设置手动 finish
 @bulk_update_decorator(AccountRisk, update_fields=["details", "status"])
 def update_risk(risk):
     return risk
@@ -217,6 +219,9 @@ class CheckAccountManager(BaseManager):
                 "details": [{"datetime": now, 'type': 'init'}],
             })
+        create_risk.finish()
+        update_risk.finish()

     def pre_run(self):
         super().pre_run()
         self.assets = self.execution.get_all_assets()

View File

@@ -30,6 +30,16 @@ common_risk_items = [
 diff_items = risk_items + common_risk_items

+@bulk_create_decorator(AccountRisk)
+def _create_risk(self, data):
+    return AccountRisk(**data)
+
+
+@bulk_update_decorator(AccountRisk, update_fields=["details"])
+def _update_risk(self, account):
+    return account
+

 def format_datetime(value):
     if isinstance(value, timezone.datetime):
         return value.strftime("%Y-%m-%d %H:%M:%S")
@@ -141,25 +151,17 @@ class AnalyseAccountRisk:
             found = assets_risks.get(key)
             if not found:
-                self._create_risk(dict(**d, details=[detail]))
+                _create_risk(dict(**d, details=[detail]))
                 continue

             found.details.append(detail)
-            self._update_risk(found)
-
-    @bulk_create_decorator(AccountRisk)
-    def _create_risk(self, data):
-        return AccountRisk(**data)
-
-    @bulk_update_decorator(AccountRisk, update_fields=["details"])
-    def _update_risk(self, account):
-        return account
+            _update_risk(found)

     def lost_accounts(self, asset, lost_users):
         if not self.check_risk:
             return

         for user in lost_users:
-            self._create_risk(
+            _create_risk(
                 dict(
                     asset_id=str(asset.id),
                     username=user,
@@ -176,7 +178,7 @@ class AnalyseAccountRisk:
                 self._analyse_item_changed(ga, d)
             if not sys_found:
                 basic = {"asset": asset, "username": d["username"], 'gathered_account': ga}
-                self._create_risk(
+                _create_risk(
                     dict(
                         **basic,
                         risk=RiskChoice.new_found,
@@ -388,8 +390,6 @@ class GatherAccountsManager(AccountBasePlaybookManager):
                 self.update_gathered_account(ori_account, d)
             ori_found = username in ori_users
             need_analyser_gather_account.append((asset, ga, d, ori_found))
-        self.create_gathered_account.finish()
-        self.update_gathered_account.finish()

         for analysis_data in need_analyser_gather_account:
             risk_analyser.analyse_risk(*analysis_data)
         self.update_gather_accounts_status(asset)
@@ -403,6 +403,11 @@ class GatherAccountsManager(AccountBasePlaybookManager):
             present=True
         )
         # 因为有 bulk create, bulk update, 所以这里需要 sleep 一下,等待数据同步
+        self.create_gathered_account.finish()
+        self.update_gathered_account.finish()
+        _update_risk.finish()
+        _create_risk.finish()
         time.sleep(0.5)

     def get_report_template(self):

View File

@@ -123,9 +123,7 @@ class BaseManager:
         self.execution.summary = self.summary
         self.execution.result = self.result
         self.execution.status = self.status
-        self.execution.save()
+        with safe_atomic_db_connection():
+            self.execution.save()

     def print_summary(self):
         content = "\nSummery: \n"
@@ -167,9 +165,10 @@ class BaseManager:
         return data

     def post_run(self):
-        self.update_execution()
-        self.print_summary()
-        self.send_report_if_need()
+        with safe_atomic_db_connection():
+            self.update_execution()
+            self.print_summary()
+            self.send_report_if_need()

     def run(self, *args, **kwargs):
         self.pre_run()
@@ -548,7 +547,8 @@ class BasePlaybookManager(PlaybookPrepareMixin, BaseManager):
         try:
             kwargs.update({"clean_workspace": False})
             cb = runner.run(**kwargs)
-            self.on_runner_success(runner, cb)
+            with safe_atomic_db_connection():
+                self.on_runner_success(runner, cb)
         except Exception as e:
             self.on_runner_failed(runner, e, **info)
         finally:

View File

@@ -89,6 +89,8 @@ def create_activities(resource_ids, detail, detail_id, action, org_id):
     for activity in activities:
         create_activity(activity)
+    create_activity.finish()

 @signals.after_task_publish.connect
 def after_task_publish_for_activity_log(headers=None, body=None, **kwargs):

View File

@@ -50,13 +50,14 @@ def get_objects(model, pks):
 # 复制 django.db.close_old_connections, 因为它没有导出,ide 提示有问题
-def close_old_connections():
-    for conn in connections.all():
+def close_old_connections(**kwargs):
+    for conn in connections.all(initialized_only=True):
         conn.close_if_unusable_or_obsolete()

+# 这个要是在 Django 请求周期外使用的,不能影响 Django 的事务管理, 在 api 中使用会影响 api 事务
 @contextmanager
-def safe_db_connection(auto_close=True):
+def safe_db_connection():
     close_old_connections()
     yield
     close_old_connections()
@@ -64,19 +65,25 @@ def safe_db_connection(auto_close=True):
 @contextmanager
 def safe_atomic_db_connection(auto_close=False):
-    in_atomic_block = connection.in_atomic_block  # 当前是否处于事务中
-    autocommit = transaction.get_autocommit()  # 是否启用了自动提交
-    created = False
+    """
+    通用数据库连接管理器(线程安全、事务感知)
+    - 在连接不可用时主动重建连接
+    - 在非事务环境下自动关闭连接(可选)
+    - 不影响 Django 请求/事务周期
+    """
+    in_atomic = connection.in_atomic_block  # 当前是否在事务中
+    autocommit = transaction.get_autocommit()
+    recreated = False
     try:
         if not connection.is_usable():
             connection.close()
             connection.connect()
-            created = True
+            recreated = True
         yield
     finally:
-        # 如果不是事务中(API 请求中可能需要提交事务),则关闭连接
-        if auto_close or (created and not in_atomic_block and autocommit):
+        # 只在非事务、autocommit 模式下,才考虑主动清理连接
+        if auto_close or (recreated and not in_atomic and autocommit):
             close_old_connections()

View File

@@ -302,16 +302,8 @@ def bulk_handle(handler, batch_size=50, timeout=0.5):
     cache = []  # 缓存实例的列表
     lock = threading.Lock()  # 用于线程安全
-    timer = [None]  # 定时器对象,列表存储以便重置
     org_id = None

-    def reset_timer():
-        """重置定时器"""
-        if timer[0] is not None:
-            timer[0].cancel()
-        timer[0] = threading.Timer(timeout, handle_remaining)
-        timer[0].start()
-
     def handle_it():
         from orgs.utils import tmp_to_org
         with lock:
@@ -351,17 +343,13 @@ def bulk_handle(handler, batch_size=50, timeout=0.5):
         if len(cache) >= batch_size:
             handle_it()
-        reset_timer()
         return instance

     # 提交剩余实例的方法
     def handle_remaining():
         if not cache:
             return
-        print("Timer expired. Saving remaining instances.")
-        from orgs.utils import tmp_to_org
-        with tmp_to_org(org_id):
-            handle_it()
+        handle_it()

     wrapper.finish = handle_remaining
     return wrapper