diff --git a/apps/common/cache.py b/apps/common/cache.py new file mode 100644 index 000000000..f7c0b16cd --- /dev/null +++ b/apps/common/cache.py @@ -0,0 +1,187 @@ +import json +from django.core.cache import cache + +from common.utils.lock import DistributedLock +from common.utils import lazyproperty +from common.utils import get_logger + +logger = get_logger(__file__) + + +class CacheFieldBase: + field_type = str + + def __init__(self, queryset=None, compute_func_name=None): + assert None in (queryset, compute_func_name), f'queryset and compute_func_name can only have one' + self.compute_func_name = compute_func_name + self.queryset = queryset + + +class CharField(CacheFieldBase): + field_type = str + + +class IntegerField(CacheFieldBase): + field_type = int + + +class CacheBase(type): + def __new__(cls, name, bases, attrs: dict): + to_update = {} + field_desc_mapper = {} + + for k, v in attrs.items(): + if isinstance(v, CacheFieldBase): + desc = CacheValueDesc(k, v) + to_update[k] = desc + field_desc_mapper[k] = desc + + attrs.update(to_update) + attrs['field_desc_mapper'] = field_desc_mapper + return type.__new__(cls, name, bases, attrs) + + +class Cache(metaclass=CacheBase): + field_desc_mapper: dict + timeout = None + + def __init__(self): + self._data = None + + @lazyproperty + def key_suffix(self): + return self.get_key_suffix() + + @property + def key_prefix(self): + clz = self.__class__ + return f'cache.{clz.__module__}.{clz.__name__}' + + @property + def key(self): + return f'{self.key_prefix}.{self.key_suffix}' + + @property + def data(self): + if self._data is None: + data = self.get_data() + if data is None: + # 缓存中没有数据时,去数据库获取 + self.compute_and_set_all_data() + return self._data + + def get_data(self) -> dict: + data = cache.get(self.key) + logger.debug(f'CACHE: get {self.key} = {data}') + if data is not None: + data = json.loads(data) + self._data = data + return data + + def set_data(self, data): + self._data = data + to_json = json.dumps(data) + logger.info(f'CACHE: set {self.key} = {to_json}, timeout={self.timeout}') + cache.set(self.key, to_json, timeout=self.timeout) + + def _compute_data(self, *fields): + field_descs = [] + if not fields: + field_descs = self.field_desc_mapper.values() + else: + for field in fields: + assert field in self.field_desc_mapper, f'{field} is not a valid field' + field_descs.append(self.field_desc_mapper[field]) + data = { + field_desc.field_name: field_desc.compute_value(self) + for field_desc in field_descs + } + return data + + def compute_and_set_all_data(self, computed_data: dict = None): + """ + TODO 怎样防止并发更新全部数据,浪费数据库资源 + """ + uncomputed_keys = () + if computed_data: + computed_keys = computed_data.keys() + all_keys = self.field_desc_mapper.keys() + uncomputed_keys = all_keys - computed_keys + else: + computed_data = {} + data = self._compute_data(*uncomputed_keys) + data.update(computed_data) + self.set_data(data) + return data + + def refresh_part_data_with_lock(self, refresh_data): + with DistributedLock(name=f'{self.key}.refresh'): + data = self.get_data() + if data is not None: + data.update(refresh_data) + self.set_data(data) + return data + + def refresh(self, *fields): + if not fields: + # 没有指定 field 要刷新所有的值 + self.compute_and_set_all_data() + return + + data = self.get_data() + if data is None: + # 缓存中没有数据,设置所有的值 + self.compute_and_set_all_data() + return + + refresh_data = self._compute_data(*fields) + if not self.refresh_part_data_with_lock(refresh_data): + # 刷新部分失败,缓存中没有数据,更新所有的值 + self.compute_and_set_all_data(refresh_data) + return + + def get_key_suffix(self): + raise NotImplementedError + + def reload(self): + self._data = None + + def delete(self): + self._data = None + logger.info(f'CACHE: delete {self.key}') + cache.delete(self.key) + + +class CacheValueDesc: + def __init__(self, field_name, field_type: CacheFieldBase): + self.field_name = field_name + self.field_type = field_type + self._data = None + + def __repr__(self): + clz = self.__class__ + return f'<{clz.__name__} {self.field_name} {self.field_type}>' + + def __get__(self, instance: Cache, owner): + if instance is None: + return self + if self.field_name not in instance.data: + instance.refresh(self.field_name) + value = instance.data[self.field_name] + return value + + def compute_value(self, instance: Cache): + if self.field_type.queryset is not None: + new_value = self.field_type.queryset.count() + else: + compute_func_name = self.field_type.compute_func_name + if not compute_func_name: + compute_func_name = f'compute_{self.field_name}' + compute_func = getattr(instance, compute_func_name, None) + assert compute_func is not None, \ + f'Define `{compute_func_name}` method in {instance.__class__}' + new_value = compute_func() + + new_value = self.field_type.field_type(new_value) + logger.info(f'CACHE: compute {instance.key}.{self.field_name} = {new_value}') + return new_value diff --git a/apps/common/const/signals.py b/apps/common/const/signals.py index b28c1310b..5d35518ab 100644 --- a/apps/common/const/signals.py +++ b/apps/common/const/signals.py @@ -12,3 +12,6 @@ PRE_REMOVE = 'pre_remove' POST_REMOVE = 'post_remove' PRE_CLEAR = 'pre_clear' POST_CLEAR = 'post_clear' + +POST_PREFIX = 'post' +PRE_PREFIX = 'pre' diff --git a/apps/common/mixins/serializers.py b/apps/common/mixins/serializers.py index d9df17e1d..020af68d7 100644 --- a/apps/common/mixins/serializers.py +++ b/apps/common/mixins/serializers.py @@ -124,6 +124,22 @@ class BulkListSerializerMixin(object): return ret + def create(self, validated_data): + ModelClass = self.child.Meta.model + use_model_bulk_create = getattr(self.child.Meta, 'use_model_bulk_create', False) + model_bulk_create_kwargs = getattr(self.child.Meta, 'model_bulk_create_kwargs', {}) + + if use_model_bulk_create: + to_create = [ + ModelClass(**attrs) for attrs in validated_data + ] + objs = ModelClass._default_manager.bulk_create( + to_create, **model_bulk_create_kwargs + ) + return objs + else: + return super().create(validated_data) + class BaseDynamicFieldsPlugin: def __init__(self, serializer): diff --git a/apps/common/tasks.py b/apps/common/tasks.py index eeeee7214..715be5103 100644 --- a/apps/common/tasks.py +++ b/apps/common/tasks.py @@ -4,7 +4,6 @@ from celery import shared_task from .utils import get_logger - logger = get_logger(__file__) diff --git a/apps/common/utils/lock.py b/apps/common/utils/lock.py index 9041a2578..04ee1520f 100644 --- a/apps/common/utils/lock.py +++ b/apps/common/utils/lock.py @@ -1,4 +1,5 @@ from functools import wraps +import threading from redis_lock import Lock as RedisLock from redis import Redis @@ -35,11 +36,16 @@ class DistributedLock(RedisLock): self._blocking = blocking def __enter__(self): + thread_id = threading.current_thread().ident + logger.debug(f'DISTRIBUTED_LOCK: attempt to acquire ...') acquired = self.acquire(blocking=self._blocking) if self._blocking and not acquired: + logger.debug(f'DISTRIBUTED_LOCK: was not acquired , but blocking=True') raise EnvironmentError("Lock wasn't acquired, but blocking=True") if not acquired: + logger.debug(f'DISTRIBUTED_LOCK: acquire failed') raise AcquireFailed + logger.debug(f'DISTRIBUTED_LOCK: acquire ok') return self def __exit__(self, exc_type=None, exc_value=None, traceback=None): diff --git a/apps/orgs/api.py b/apps/orgs/api.py index 6b7b50401..1a6545f0b 100644 --- a/apps/orgs/api.py +++ b/apps/orgs/api.py @@ -75,11 +75,6 @@ class OrgMemberRelationBulkViewSet(JMSBulkRelationModelViewSet): filterset_class = OrgMemberRelationFilterSet search_fields = ('user__name', 'user__username', 'org__name') - def perform_bulk_create(self, serializer): - data = serializer.validated_data - relations = [OrganizationMember(**i) for i in data] - OrganizationMember.objects.bulk_create(relations, ignore_conflicts=True) - def perform_bulk_destroy(self, queryset): objs = list(queryset.all().prefetch_related('user', 'org')) queryset.delete() diff --git a/apps/orgs/cache.py b/apps/orgs/cache.py new file mode 100644 index 000000000..c77dd988f --- /dev/null +++ b/apps/orgs/cache.py @@ -0,0 +1,34 @@ +from django.db.transaction import on_commit + +from common.cache import * +from .utils import current_org, tmp_to_org +from .tasks import refresh_org_cache_task +from orgs.models import Organization + + +class OrgRelatedCache(Cache): + + def __init__(self): + super().__init__() + self.current_org = Organization.get_instance(current_org.id) + + def get_current_org(self): + """ + 暴露给子类控制组织的回调 + 1. 在交互式环境下能控制组织 + 2. 在 celery 任务下能控制组织 + """ + return self.current_org + + def refresh(self, *fields): + with tmp_to_org(self.get_current_org()): + return super().refresh(*fields) + + def refresh_async(self, *fields): + """ + 在事务提交之后再发送信号,防止因事务的隔离性导致未获得最新的数据 + """ + def func(): + logger.info(f'CACHE: Send refresh task {self}.{fields}') + refresh_org_cache_task.delay(self, *fields) + on_commit(func) diff --git a/apps/orgs/caches.py b/apps/orgs/caches.py new file mode 100644 index 000000000..462b13594 --- /dev/null +++ b/apps/orgs/caches.py @@ -0,0 +1,46 @@ +from .cache import OrgRelatedCache, IntegerField +from users.models import UserGroup, User +from assets.models import Node, AdminUser, SystemUser, Domain, Gateway +from applications.models import Application +from perms.models import AssetPermission, ApplicationPermission +from .models import OrganizationMember + + +class OrgResourceStatisticsCache(OrgRelatedCache): + users_amount = IntegerField() + groups_amount = IntegerField(queryset=UserGroup.objects) + + assets_amount = IntegerField() + nodes_amount = IntegerField(queryset=Node.objects) + admin_users_amount = IntegerField(queryset=AdminUser.objects) + system_users_amount = IntegerField(queryset=SystemUser.objects) + domains_amount = IntegerField(queryset=Domain.objects) + gateways_amount = IntegerField(queryset=Gateway.objects) + + applications_amount = IntegerField(queryset=Application.objects) + + asset_perms_amount = IntegerField(queryset=AssetPermission.objects) + app_perms_amount = IntegerField(queryset=ApplicationPermission.objects) + + def __init__(self, org): + super().__init__() + self.org = org + + def get_key_suffix(self): + return f'' + + def get_current_org(self): + return self.org + + def compute_users_amount(self): + if self.org.is_real(): + users_amount = OrganizationMember.objects.values( + 'user_id' + ).filter(org_id=self.org.id).distinct().count() + else: + users_amount = User.objects.all().distinct().count() + return users_amount + + def compute_assets_amount(self): + node = Node.org_root() + return node.assets_amount diff --git a/apps/orgs/models.py b/apps/orgs/models.py index 715067b1c..1fbae50ce 100644 --- a/apps/orgs/models.py +++ b/apps/orgs/models.py @@ -7,7 +7,7 @@ from django.db.models import signals from django.db.models import Q from django.utils.translation import ugettext_lazy as _ -from common.utils import is_uuid +from common.utils import is_uuid, lazyproperty from common.const import choices from common.db.models import ChoiceSet @@ -215,6 +215,33 @@ class Organization(models.Model): from .utils import set_current_org set_current_org(self) + @lazyproperty + def resource_statistics_cache(self): + from .caches import OrgResourceStatisticsCache + return OrgResourceStatisticsCache(self) + + def get_total_resources_amount(self): + from django.apps import apps + from orgs.mixins.models import OrgModelMixin + summary = {'users.Members': self.members.all().count()} + for app_name, app_config in apps.app_configs.items(): + models_cls = app_config.get_models() + for model in models_cls: + if not issubclass(model, OrgModelMixin): + continue + key = '{}.{}'.format(app_name, model.__name__) + summary[key] = self.get_resource_amount(model) + return summary + + def get_resource_amount(self, resource_model): + from .utils import tmp_to_org + from .mixins.models import OrgModelMixin + + if not issubclass(resource_model, OrgModelMixin): + return 0 + with tmp_to_org(self): + return resource_model.objects.all().count() + def _convert_to_uuid_set(users): rst = set() diff --git a/apps/orgs/serializers.py b/apps/orgs/serializers.py index 432670bb6..d10963531 100644 --- a/apps/orgs/serializers.py +++ b/apps/orgs/serializers.py @@ -10,18 +10,37 @@ from common.db.models import concated_display as display from .models import Organization, OrganizationMember, ROLE +class ResourceStatisticsSerializer(serializers.Serializer): + users_amount = serializers.IntegerField(required=False) + groups_amount = serializers.IntegerField(required=False) + + assets_amount = serializers.IntegerField(required=False) + nodes_amount = serializers.IntegerField(required=False) + admin_users_amount = serializers.IntegerField(required=False) + system_users_amount = serializers.IntegerField(required=False) + domains_amount = serializers.IntegerField(required=False) + gateways_amount = serializers.IntegerField(required=False) + + applications_amount = serializers.IntegerField(required=False) + asset_perms_amount = serializers.IntegerField(required=False) + app_perms_amount = serializers.IntegerField(required=False) + + class OrgSerializer(ModelSerializer): users = serializers.PrimaryKeyRelatedField(many=True, queryset=User.objects.all(), write_only=True, required=False) admins = serializers.PrimaryKeyRelatedField(many=True, queryset=User.objects.all(), write_only=True, required=False) auditors = serializers.PrimaryKeyRelatedField(many=True, queryset=User.objects.all(), write_only=True, required=False) + resource_statistics = ResourceStatisticsSerializer(source='resource_statistics_cache') + class Meta: model = Organization list_serializer_class = AdaptedBulkListSerializer fields_mini = ['id', 'name'] fields_small = fields_mini + [ - 'created_by', 'date_created', 'comment' + 'created_by', 'date_created', 'comment', 'resource_statistics' ] + fields_m2m = ['users', 'admins', 'auditors'] fields = fields_small + fields_m2m read_only_fields = ['created_by', 'date_created'] @@ -60,6 +79,8 @@ class OrgMemberSerializer(BulkModelSerializer): class Meta: model = OrganizationMember fields = ('id', 'org', 'user', 'role', 'org_display', 'user_display', 'role_display') + use_model_bulk_create = True + model_bulk_create_kwargs = {'ignore_conflicts': True} def get_unique_together_validators(self): if self.parent: diff --git a/apps/orgs/signals_handler.py b/apps/orgs/signals_handler.py index adfd9d373..5f918ee21 100644 --- a/apps/orgs/signals_handler.py +++ b/apps/orgs/signals_handler.py @@ -4,7 +4,7 @@ from collections import defaultdict from functools import partial from django.db.models.signals import m2m_changed -from django.db.models.signals import post_save +from django.db.models.signals import post_save, pre_delete from django.dispatch import receiver from orgs.utils import tmp_to_org @@ -12,7 +12,10 @@ from .models import Organization, OrganizationMember from .hands import set_current_org, Node, get_current_org from perms.models import (AssetPermission, ApplicationPermission) from users.models import UserGroup, User -from common.const.signals import PRE_REMOVE, POST_REMOVE +from applications.models import Application +from assets.models import Asset, AdminUser, SystemUser, Domain, Gateway +from common.const.signals import PRE_REMOVE, POST_REMOVE, POST_PREFIX +from .caches import OrgResourceStatisticsCache @receiver(post_save, sender=Organization) @@ -106,3 +109,72 @@ def on_org_user_changed(action, instance, reverse, pk_set, **kwargs): leaved_users = set(pk_set) - set(org.members.filter(id__in=user_pk_set).values_list('id', flat=True)) _clear_users_from_org(org, leaved_users) + + +# 缓存相关 +# ----------------------------------------------------- + +def refresh_user_amount_on_user_create_or_delete(user_id): + orgs = Organization.objects.filter(m2m_org_members__user_id=user_id).distinct() + for org in orgs: + org_cache = OrgResourceStatisticsCache(org) + org_cache.refresh_async('users_amount') + + +@receiver(post_save, sender=User) +def on_user_create(sender, instance, created, **kwargs): + if created: + refresh_user_amount_on_user_create_or_delete(instance.id) + + +@receiver(pre_delete, sender=User) +def on_user_delete(sender, instance, **kwargs): + refresh_user_amount_on_user_create_or_delete(instance.id) + + +@receiver(m2m_changed, sender=OrganizationMember) +def on_org_user_changed(sender, action, instance, reverse, pk_set, **kwargs): + if not action.startswith(POST_PREFIX): + return + + if reverse: + orgs = Organization.objects.filter(id__in=pk_set) + else: + orgs = [instance] + + for org in orgs: + org_cache = OrgResourceStatisticsCache(org) + org_cache.refresh_async('users_amount') + + +class OrgResourceStatisticsRefreshUtil: + model_cache_field_mapper = { + ApplicationPermission: 'app_perms_amount', + AssetPermission: 'asset_perms_amount', + Application: 'applications_amount', + Gateway: 'gateways_amount', + Domain: 'domains_amount', + SystemUser: 'system_users_amount', + AdminUser: 'admin_users_amount', + Node: 'nodes_amount', + Asset: 'assets_amount', + UserGroup: 'groups_amount', + } + + @classmethod + def refresh_if_need(cls, instance): + cache_field_name = cls.model_cache_field_mapper.get(type(instance)) + if cache_field_name: + org_cache = OrgResourceStatisticsCache(instance.org) + org_cache.refresh_async(cache_field_name) + + +@receiver(post_save) +def on_post_save_refresh_org_resource_statistics_cache(sender, instance, created, **kwargs): + if created: + OrgResourceStatisticsRefreshUtil.refresh_if_need(instance) + + +@receiver(pre_delete) +def on_pre_delete_refresh_org_resource_statistics_cache(sender, instance, **kwargs): + OrgResourceStatisticsRefreshUtil.refresh_if_need(instance) diff --git a/apps/orgs/tasks.py b/apps/orgs/tasks.py new file mode 100644 index 000000000..a33456913 --- /dev/null +++ b/apps/orgs/tasks.py @@ -0,0 +1,11 @@ +from celery import shared_task + +from common.utils import get_logger + +logger = get_logger(__file__) + + +@shared_task +def refresh_org_cache_task(cache, *fields): + logger.info(f'CACHE: refresh {cache.key}.{fields}') + cache.refresh(*fields) diff --git a/apps/users/models/user.py b/apps/users/models/user.py index 20701acac..36e5eb330 100644 --- a/apps/users/models/user.py +++ b/apps/users/models/user.py @@ -18,7 +18,7 @@ from django.shortcuts import reverse from common.local import LOCAL_DYNAMIC_SETTINGS from orgs.utils import current_org -from orgs.models import OrganizationMember +from orgs.models import OrganizationMember, Organization from common.utils import date_expired_default, get_logger, lazyproperty from common import fields from common.const import choices @@ -327,7 +327,8 @@ class RoleMixin: def remove(self): if not current_org.is_real(): return - OrganizationMember.objects.remove_users(current_org, [self]) + org = Organization.get_instance(current_org.id) + OrganizationMember.objects.remove_users(org, [self]) @classmethod def get_super_admins(cls):