perf: Optimize slow command recording

Author: xinwen
Date: 2021-02-22 18:35:53 +08:00
Parent: 7f42e59714
Commit: 3e7e01418d
9 changed files with 427 additions and 160 deletions
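
The slowness named in the title appears to be addressed, at least in part, by the new bulk_save in the diff below, which sends a whole batch of commands to Elasticsearch in a single elasticsearch.helpers.bulk request rather than one write per command. A minimal usage sketch under that reading; the host, index, and command values are placeholders, and CommandStore is the class added in this diff:

store = CommandStore({
    'HOSTS': ['http://localhost:9200'],   # placeholder ES endpoint
    'INDEX': 'jumpserver',
    'DOC_TYPE': 'command_store',
})
commands = [{
    'user': 'admin', 'asset': 'server-01', 'system_user': 'root',
    'input': 'ls -al', 'output': 'total 0', 'risk_level': 0,
    'session': '0f1f2a3e-0000-0000-0000-000000000000',   # placeholder session id
    'org_id': '', 'timestamp': 1613989753,
}]
store.bulk_save(commands)   # one bulk request, not len(commands) individual index calls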


@@ -1,53 +1,270 @@
# -*- coding: utf-8 -*-
#
from datetime import datetime
from functools import reduce, partial
from itertools import groupby
import pytz
from uuid import UUID
import inspect

from django.db.models import QuerySet as DJQuerySet
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk

from common.utils.common import lazyproperty
from common.utils import get_logger
from .models import AbstractSessionCommand

logger = get_logger(__file__)


class CommandStore():
    def __init__(self, config):
        hosts = config.get("HOSTS")
        kwargs = config.get("OTHER", {})
        self.index = config.get("INDEX") or 'jumpserver'
        self.doc_type = config.get("DOC_TYPE") or 'command_store'
        self.es = Elasticsearch(hosts=hosts, **kwargs)

    @staticmethod
    def make_data(command):
        data = dict(
            user=command["user"], asset=command["asset"],
            system_user=command["system_user"], input=command["input"],
            output=command["output"], risk_level=command["risk_level"],
            session=command["session"], timestamp=command["timestamp"],
            org_id=command["org_id"]
        )
        data["date"] = datetime.fromtimestamp(command['timestamp'], tz=pytz.UTC)
        return data

    def bulk_save(self, command_set, raise_on_error=True):
        actions = []
        for command in command_set:
            data = dict(
                _index=self.index,
                _type=self.doc_type,
                _source=self.make_data(command),
            )
            actions.append(data)
        return bulk(self.es, actions, index=self.index, raise_on_error=raise_on_error)

    def save(self, command):
        """
        Save a single command to the store
        """
        data = self.make_data(command)
        return self.es.index(index=self.index, doc_type=self.doc_type, body=data)

    def filter(self, query: dict, from_=None, size=None, sort=None):
        body = self.get_query_body(**query)
        data = self.es.search(
            index=self.index, doc_type=self.doc_type, body=body,
            from_=from_, size=size, sort=sort
        )
        return AbstractSessionCommand.from_multi_dict(
            [item['_source'] for item in data['hits']['hits'] if item]
        )

    def count(self, **query):
        body = self.get_query_body(**query)
        data = self.es.count(index=self.index, doc_type=self.doc_type, body=body)
        return data["count"]

    def __getattr__(self, item):
        return getattr(self.es, item)

    def all(self):
        """Return all data"""
        raise NotImplementedError("Not support")

    def ping(self):
        try:
            return self.es.ping()
        except Exception:
            return False

    @staticmethod
    def get_query_body(**kwargs):
        new_kwargs = {}
        for k, v in kwargs.items():
            new_kwargs[k] = str(v) if isinstance(v, UUID) else v
        kwargs = new_kwargs

        exact_fields = {}
        match_fields = {'session', 'input', 'org_id', 'risk_level', 'user', 'asset', 'system_user'}

        match = {}
        exact = {}
        for k, v in kwargs.items():
            if k in exact_fields:
                exact[k] = v
            elif k in match_fields:
                match[k] = v

        # Handle the time range
        timestamp__gte = kwargs.get('timestamp__gte')
        timestamp__lte = kwargs.get('timestamp__lte')
        timestamp_range = {}
        if timestamp__gte:
            timestamp_range['gte'] = timestamp__gte
        if timestamp__lte:
            timestamp_range['lte'] = timestamp__lte

        # Handle the organization: an empty org_id matches commands stored without an org_id
        must_not = []
        org_id = match.get('org_id')
        if org_id == '':
            match.pop('org_id')
            must_not.append({'wildcard': {'org_id': '*'}})

        # Build the query body
        body = {
            'query': {
                'bool': {
                    'must': [
                        {'match': {k: v}} for k, v in match.items()
                    ],
                    'must_not': must_not,
                    'filter': [
                        {'term': {k: v}} for k, v in exact.items()
                    ] + [
                        {'range': {'timestamp': timestamp_range}}
                    ]
                }
            },
        }
        return body

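For reference, here is the body the query builder above produces for a typical filter; the argument values are made up:

body = CommandStore.get_query_body(user='admin', risk_level=5, timestamp__gte=1613900000)
# body is roughly:
# {
#     'query': {
#         'bool': {
#             'must': [{'match': {'user': 'admin'}}, {'match': {'risk_level': 5}}],
#             'must_not': [],
#             'filter': [{'range': {'timestamp': {'gte': 1613900000}}}]
#         }
#     }
# }
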
class QuerySet(DJQuerySet):
    _method_calls = None
    _storage = None
    _command_store_config = None
    _slice = None  # (from_, size)
    default_days_ago = 5
    max_result_window = 10000

    def __init__(self, command_store_config):
        self._method_calls = []
        self._command_store_config = command_store_config
        self._storage = CommandStore(command_store_config)

    @lazyproperty
    def _grouped_method_calls(self):
        _method_calls = {k: list(v) for k, v in groupby(self._method_calls, lambda x: x[0])}
        return _method_calls

    @lazyproperty
    def _filter_kwargs(self):
        _method_calls = self._grouped_method_calls
        filter_calls = _method_calls.get('filter')
        if not filter_calls:
            return {}
        names, multi_args, multi_kwargs = zip(*filter_calls)
        kwargs = reduce(lambda x, y: {**x, **y}, multi_kwargs, {})

        striped_kwargs = {}
        for k, v in kwargs.items():
            k = k.replace('__exact', '')
            k = k.replace('__startswith', '')
            k = k.replace('__icontains', '')
            striped_kwargs[k] = v
        return striped_kwargs

    @lazyproperty
    def _sort(self):
        order_by = self._grouped_method_calls.get('order_by')
        if order_by:
            for call in reversed(order_by):
                fields = call[1]
                if fields:
                    field = fields[-1]
                    if field.startswith('-'):
                        direction = 'desc'
                    else:
                        direction = 'asc'
                    field = field.lstrip('-+')
                    sort = f'{field}:{direction}'
                    return sort

    def __execute(self):
        _filter_kwargs = self._filter_kwargs
        _sort = self._sort
        from_, size = self._slice or (None, None)
        data = self._storage.filter(_filter_kwargs, from_=from_, size=size, sort=_sort)
        return data

    def __stage_method_call(self, item, *args, **kwargs):
        _clone = self.__clone()
        _clone._method_calls.append((item, args, kwargs))
        return _clone

    def __clone(self):
        uqs = QuerySet(self._command_store_config)
        uqs._method_calls = self._method_calls.copy()
        uqs._slice = self._slice
        return uqs

    def count(self, limit_to_max_result_window=True):
        filter_kwargs = self._filter_kwargs
        count = self._storage.count(**filter_kwargs)
        if limit_to_max_result_window:
            count = min(count, self.max_result_window)
        return count

    def __getattribute__(self, item):
        if any((
            item.startswith('__'),
            item in QuerySet.__dict__,
        )):
            return object.__getattribute__(self, item)

        origin_attr = object.__getattribute__(self, item)
        if not inspect.ismethod(origin_attr):
            return origin_attr

        attr = partial(self.__stage_method_call, item)
        return attr

    def __getitem__(self, item):
        max_window = self.max_result_window
        if isinstance(item, slice):
            if self._slice is None:
                clone = self.__clone()
                from_ = item.start or 0
                if item.stop is None:
                    size = 10
                else:
                    size = item.stop - from_

                if from_ + size > max_window:
                    if from_ >= max_window:
                        from_ = max_window
                        size = 0
                    else:
                        size = max_window - from_
                clone._slice = (from_, size)
                return clone
        return self.__execute()[item]

    def __repr__(self):
        return self.__execute().__repr__()

    def __iter__(self):
        return iter(self.__execute())

    def __len__(self):
        return self.count()
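
A sketch of how this lazy QuerySet is meant to be driven (config values are placeholders, mirroring CommandStore above): filter and order_by are inherited Django QuerySet methods, so __getattribute__ intercepts them and only records the calls; slicing records the from_/size window; iteration or repr finally runs one es.search via __execute, and count issues one es.count capped at max_result_window.

qs = QuerySet({
    'HOSTS': ['http://localhost:9200'],   # placeholder ES endpoint
    'INDEX': 'jumpserver',
    'DOC_TYPE': 'command_store',
})
staged = qs.filter(user='admin', risk_level=5).order_by('-timestamp')  # nothing sent yet
page = staged[0:20]       # still lazy: only stores _slice = (0, 20)
commands = list(page)     # __iter__ -> __execute: one es.search with from_=0, size=20, sort='timestamp:desc'
total = staged.count()    # one es.count, capped at max_result_window (10000)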