perf: Reduce the number of pub sub processing threads (#16072)

* perf: Reduce the number of pub sub processing threads

* perf: Using thread pool to process messages

---------

Co-authored-by: wangruidong <940853815@qq.com>
This commit is contained in:
fit2bot
2025-10-21 17:41:14 +08:00
committed by GitHub
parent d68babb2e1
commit 70068c9253

View File

@@ -1,16 +1,137 @@
import json
import threading
import time
import redis
from django.core.cache import cache
from redis.client import PubSub
from common.db.utils import safe_db_connection
from common.utils import get_logger
logger = get_logger(__name__)
import threading
from concurrent.futures import ThreadPoolExecutor
# Process-wide registry mapping Redis db number -> shared PubSubHub.
_PUBSUB_HUBS = {}


def _get_pubsub_hub(db=10):
    """Return the process-wide hub for Redis db *db*, creating it lazily."""
    existing = _PUBSUB_HUBS.get(db)
    if existing:
        return existing
    created = PubSubHub(db=db)
    _PUBSUB_HUBS[db] = created
    return created
class PubSubHub:
    """Multiplex many channel subscriptions over one Redis pub/sub connection.

    A single daemon listener thread consumes raw messages and hands each one
    to a small thread pool, instead of spawning one listener thread per
    subscription. One hub instance is shared per Redis db (see
    ``_get_pubsub_hub``).
    """

    def __init__(self, db=10):
        self.db = db
        self.redis = get_redis_client(db)
        self.pubsub = self.redis.pubsub()
        # channel name -> (_next, error, complete) handler triple
        self.handlers = {}
        self.lock = threading.RLock()
        self.listener = None
        self.running = False
        # Handlers run here so one slow consumer cannot stall the listen loop.
        self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='pubsub_handler')

    def __del__(self):
        # Best effort only: don't block garbage collection / interpreter
        # shutdown waiting for in-flight handlers to finish.
        self.executor.shutdown(wait=False)

    def start(self):
        """Start the background listener thread (idempotent)."""
        with self.lock:
            if self.listener and self.listener.is_alive():
                return
            self.running = True
            self.listener = threading.Thread(
                name='pubsub_listen', target=self._listen_loop, daemon=True)
            self.listener.start()

    @staticmethod
    def _decode(data):
        """Best-effort JSON decode of a message payload; fall back to raw data."""
        try:
            if isinstance(data, bytes):
                return json.loads(data.decode())
            if isinstance(data, str):
                return json.loads(data)
        except Exception:
            pass
        return data

    def _listen_loop(self):
        """Consume pub/sub messages and dispatch them on the thread pool.

        On any listen error, sleeps with exponential backoff (capped at 30s)
        and rebuilds the connection before retrying.
        """
        backoff = 1
        while self.running:
            try:
                for msg in self.pubsub.listen():
                    if msg.get("type") != "message":
                        continue
                    ch = msg.get("channel")
                    if isinstance(ch, bytes):
                        ch = ch.decode()
                    item = self._decode(msg.get("data"))
                    # Dispatch the handler work onto the thread pool.
                    future = self.executor.submit(self._dispatch, ch, msg, item)
                    # Bind msg as a default argument: a plain closure would
                    # capture the loop variable late and, by the time the
                    # callback fires, log a different (later) message than
                    # the one that actually failed.
                    future.add_done_callback(
                        lambda f, m=msg: f.exception() and logger.error(
                            f"handle pubsub msg {m} failed: {f.exception()}"))
                    backoff = 1
            except Exception as e:
                logger.error(f'PubSub listen error: {e}')
                time.sleep(backoff)
                backoff = min(backoff * 2, 30)
                try:
                    self._reconnect()
                except Exception as re:
                    logger.error(f'PubSub reconnect error: {re}')

    def _dispatch(self, ch, raw_msg, item):
        """Run the registered handler for *ch*; executes on a pool thread."""
        with self.lock:
            handler = self.handlers.get(ch)
        if not handler:
            return
        _next, error, _complete = handler
        try:
            # Make sure the worker thread has a usable DB connection.
            with safe_db_connection():
                _next(item)
        except Exception as e:
            logger.error(f'Subscribe handler handle msg error: {e}')
            try:
                if error:
                    error(raw_msg, item)
            except Exception:
                pass

    def add_subscription(self, pb, _next, error, complete):
        """Register handlers for ``pb.ch`` and return a Subscription handle.

        Re-registering an already-subscribed channel replaces its handlers
        without issuing another SUBSCRIBE on the shared connection.
        """
        ch = pb.ch
        with self.lock:
            existed = bool(self.handlers.get(ch))
            self.handlers[ch] = (_next, error, complete)
            try:
                if not existed:
                    self.pubsub.subscribe(ch)
            except Exception as e:
                logger.error(f'Subscribe channel {ch} error: {e}')
        self.start()
        return Subscription(pb=pb, hub=self, ch=ch, handler=(_next, error, complete))

    def remove_subscription(self, sub):
        """Drop the handler for ``sub.ch`` and unsubscribe it from Redis."""
        ch = sub.ch
        with self.lock:
            existed = self.handlers.pop(ch, None)
            if existed:
                try:
                    self.pubsub.unsubscribe(ch)
                except Exception as e:
                    logger.warning(f'Unsubscribe {ch} error: {e}')

    def _reconnect(self):
        """Rebuild the Redis connection and re-subscribe all live channels."""
        with self.lock:
            channels = [ch for ch, h in self.handlers.items() if h]
            try:
                self.pubsub.close()
            except Exception:
                pass
            self.redis = get_redis_client(self.db)
            self.pubsub = self.redis.pubsub()
            if channels:
                # redis-py accepts a list as the first positional argument.
                self.pubsub.subscribe(channels)
def get_redis_client(db=0):
client = cache.client.get_client()
@@ -25,15 +146,11 @@ class RedisPubSub:
self.redis = get_redis_client(db)
def subscribe(self, _next, error=None, complete=None):
ps = self.redis.pubsub()
ps.subscribe(self.ch)
sub = Subscription(self, ps)
sub.keep_handle_msg(_next, error, complete)
return sub
hub = _get_pubsub_hub(self.db)
return hub.add_subscription(self, _next, error, complete)
def resubscribe(self, _next, error=None, complete=None):
self.redis = get_redis_client(self.db)
self.subscribe(_next, error, complete)
return self.subscribe(_next, error, complete)
def publish(self, data):
data_json = json.dumps(data)
@@ -42,85 +159,19 @@ class RedisPubSub:
class Subscription:
def __init__(self, pb: RedisPubSub, sub: PubSub):
def __init__(self, pb: RedisPubSub, hub: PubSubHub, ch: str, handler):
self.pb = pb
self.ch = pb.ch
self.sub = sub
self.ch = ch
self.hub = hub
self.handler = handler
self.unsubscribed = False
def _handle_msg(self, _next, error, complete):
"""
handle arg is the pub published
:param _next: next msg handler
:param error: error msg handler
:param complete: complete msg handler
:return:
"""
msgs = self.sub.listen()
if error is None:
error = lambda m, i: None
if complete is None:
complete = lambda: None
try:
for msg in msgs:
if msg["type"] != "message":
continue
item = None
try:
item_json = msg['data'].decode()
item = json.loads(item_json)
with safe_db_connection():
_next(item)
except Exception as e:
error(msg, item)
logger.error('Subscribe handler handle msg error: {}'.format(e))
except Exception as e:
if self.unsubscribed:
logger.debug('Subscription unsubscribed')
else:
logger.error('Consume msg error: {}'.format(e))
self.retry(_next, error, complete)
return
try:
complete()
except Exception as e:
logger.error('Complete subscribe error: {}'.format(e))
pass
try:
self.unsubscribe()
except Exception as e:
logger.error("Redis observer close error: {}".format(e))
def keep_handle_msg(self, _next, error, complete):
t = threading.Thread(target=self._handle_msg, args=(_next, error, complete))
t.daemon = True
t.start()
return t
def unsubscribe(self):
if self.unsubscribed:
return
self.unsubscribed = True
logger.info(f"Unsubscribed from channel: {self.sub}")
logger.info(f"Unsubscribed from channel: {self.ch}")
try:
self.sub.close()
self.hub.remove_subscription(self)
except Exception as e:
logger.warning(f'Unsubscribe msg error: {e}')
def retry(self, _next, error, complete):
logger.info('Retry subscribe channel: {}'.format(self.ch))
times = 0
while True:
try:
self.unsubscribe()
self.pb.resubscribe(_next, error, complete)
break
except Exception as e:
logger.error('Retry #{} {} subscribe channel error: {}'.format(times, self.ch, e))
times += 1
time.sleep(times * 2)