mirror of
				https://github.com/jumpserver/jumpserver.git
				synced 2025-10-24 17:34:04 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			80 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			80 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import datetime
 | |
| 
 | |
| from channels.generic.websocket import JsonWebsocketConsumer
 | |
| from django.utils import timezone
 | |
| from rest_framework.renderers import JSONRenderer
 | |
| 
 | |
| from common.db.utils import safe_db_connection
 | |
| from common.utils import get_logger
 | |
| from common.utils.connection import Subscription
 | |
| from terminal.const import TaskNameType
 | |
| from terminal.models import Session, Terminal
 | |
| from terminal.serializers import TaskSerializer, StatSerializer
 | |
| from .signal_handlers import component_event_chan
 | |
| 
 | |
# Module-level logger for this consumer module, obtained via common.utils.get_logger.
logger = get_logger(__name__)
 | |
| 
 | |
| 
 | |
class TerminalTaskWebsocket(JsonWebsocketConsumer):
    """WebSocket consumer for a terminal component.

    Accepts only authenticated users that are bound to a ``Terminal``.
    While connected it:

    * pushes the terminal's unfinished, recent tasks as rendered JSON bytes
      (once on connect, then again on every task event received from
      ``component_event_chan``);
    * accepts ``{"type": "status", "payload": {...}}`` frames and persists
      them through ``StatSerializer``.
    """
    # Subscription returned by component_event_chan.subscribe(); None when not subscribed.
    sub: Subscription = None
    # Terminal bound to the connecting user; set in connect().
    terminal: Terminal = None
    # Only unfinished tasks created within this window are sent to the terminal.
    task_lookback = datetime.timedelta(minutes=10)

    def connect(self):
        """Accept the socket for terminal-bound users, otherwise close it.

        ``user.terminal`` is read with ``getattr(..., None)``. If it is a Django
        reverse one-to-one relation (presumably -- confirm on the User model),
        accessing it for a user with no terminal raises
        ``RelatedObjectDoesNotExist``, an ``AttributeError`` subclass. Without
        the default, that would abort the handshake with an error instead of a
        clean close.
        """
        user = self.scope["user"]
        terminal = getattr(user, 'terminal', None) if user.is_authenticated else None
        if not terminal:
            self.close()
            return
        self.accept()
        # terminal must be set before subscribing: the initial task push reads it.
        self.terminal = terminal
        self.sub = self.watch_component_event()

    def receive_json(self, content, **kwargs):
        """Dispatch an incoming JSON frame by its ``type`` field.

        Only ``"status"`` frames are handled. Other types, and frames whose
        top-level JSON value is not an object, are ignored.
        """
        if not isinstance(content, dict):
            logger.warning('Invalid ws message, expect a json object: {}'.format(content))
            return
        if content.get('type') == "status":
            self.handle_status(content.get('payload'))

    def handle_status(self, content):
        """Validate a status payload and save it for this terminal.

        :param content: the ``payload`` of a ``"status"`` frame, fed to
            ``StatSerializer``. Invalid data is logged and dropped.
        """
        serializer = StatSerializer(data=content)
        if not serializer.is_valid():
            logger.error('Invalid status data: {}'.format(serializer.errors))
            return
        serializer.validated_data["terminal"] = self.terminal
        # ``sessions`` is not saved with the stat record. It is popped and
        # handed to Session.set_sessions_active instead.
        session_ids = serializer.validated_data.pop('sessions', [])
        Session.set_sessions_active(session_ids)
        with safe_db_connection():
            serializer.save()

    def send_tasks_msg(self, task_id=None):
        """Send this terminal's pending tasks (optionally one task) as bytes."""
        content = self.get_terminal_tasks(task_id)
        self.send(bytes_data=content)

    def get_terminal_tasks(self, task_id=None):
        """Return unfinished tasks created within ``task_lookback``, rendered as JSON bytes.

        :param task_id: when truthy, restrict the result to that task id.
        """
        with safe_db_connection():
            critical_time = timezone.now() - self.task_lookback
            tasks = self.terminal.task_set.filter(is_finished=False, date_created__gte=critical_time)
            if task_id:
                tasks = tasks.filter(id=task_id)
            serializer = TaskSerializer(tasks, many=True)
            return JSONRenderer().render(serializer.data)

    def watch_component_event(self):
        """Push current tasks once, then subscribe to component task events.

        :returns: the subscription object from ``component_event_chan.subscribe``.
        """
        # Send the already-existing tasks once first.
        self.send_tasks_msg()

        def handle_task_msg_recv(msg):
            logger.debug('New component task msg recv: {}'.format(msg))
            if not isinstance(msg, dict):
                return
            if msg.get('type') not in TaskNameType.names:
                return
            payload = msg.get('payload')
            # A missing or malformed payload falls back to resending all
            # pending tasks instead of raising inside the subscriber callback.
            task_id = payload.get('id') if isinstance(payload, dict) else None
            self.send_tasks_msg(task_id)

        return component_event_chan.subscribe(handle_task_msg_recv)

    def disconnect(self, code):
        """Unsubscribe from component events; safe to call more than once."""
        if self.sub is None:
            return
        self.sub.unsubscribe()
        self.sub = None
 |