Compare commits

...

2 Commits
dev ... v5

Author SHA1 Message Date
feng
f362163af1 perf: remove gpt model 2025-12-16 13:19:45 +08:00
fit2bot
5f1ba56e56 Merge pull request #16094 from jumpserver/pr@dev@chat_model
perf: Add open ui chat model
2025-12-10 10:43:14 +08:00
28 changed files with 265 additions and 195 deletions

View File

@@ -186,6 +186,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._init_field_choices() self._init_field_choices()
self._extract_accounts() self._extract_accounts()
self._set_platform()
def _extract_accounts(self): def _extract_accounts(self):
if not getattr(self, 'initial_data', None): if not getattr(self, 'initial_data', None):
@@ -217,6 +218,21 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
protocols_data = [{'name': p.name, 'port': p.port} for p in protocols] protocols_data = [{'name': p.name, 'port': p.port} for p in protocols]
self.initial_data['protocols'] = protocols_data self.initial_data['protocols'] = protocols_data
def _set_platform(self):
if not hasattr(self, 'initial_data'):
return
platform_id = self.initial_data.get('platform')
if not platform_id:
return
if isinstance(platform_id, int) or str(platform_id).isdigit() or not isinstance(platform_id, str):
return
platform = Platform.objects.filter(name=platform_id).first()
if not platform:
return
self.initial_data['platform'] = platform.id
def _init_field_choices(self): def _init_field_choices(self):
request = self.context.get('request') request = self.context.get('request')
if not request: if not request:
@@ -265,8 +281,10 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
if not platform_id and self.instance: if not platform_id and self.instance:
platform = self.instance.platform platform = self.instance.platform
else: elif isinstance(platform_id, int):
platform = Platform.objects.filter(id=platform_id).first() platform = Platform.objects.filter(id=platform_id).first()
else:
platform = Platform.objects.filter(name=platform_id).first()
if not platform: if not platform:
raise serializers.ValidationError({'platform': _("Platform not exist")}) raise serializers.ValidationError({'platform': _("Platform not exist")})
@@ -297,6 +315,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
def is_valid(self, raise_exception=False): def is_valid(self, raise_exception=False):
self._set_protocols_default() self._set_protocols_default()
self._set_platform()
return super().is_valid(raise_exception=raise_exception) return super().is_valid(raise_exception=raise_exception)
def validate_protocols(self, protocols_data): def validate_protocols(self, protocols_data):

View File

@@ -67,6 +67,7 @@ class UserLoginMFAView(mixins.AuthMixin, FormView):
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
user = self.get_user_from_session() user = self.get_user_from_session()
mfa_context = self.get_user_mfa_context(user) mfa_context = self.get_user_mfa_context(user)
print(mfa_context)
kwargs.update(mfa_context) kwargs.update(mfa_context)
return kwargs return kwargs

View File

@@ -701,15 +701,7 @@ class Config(dict):
'CHAT_AI_ENABLED': False, 'CHAT_AI_ENABLED': False,
'CHAT_AI_METHOD': 'api', 'CHAT_AI_METHOD': 'api',
'CHAT_AI_EMBED_URL': '', 'CHAT_AI_EMBED_URL': '',
'CHAT_AI_TYPE': 'gpt', 'CHAT_AI_PROVIDERS': [],
'GPT_BASE_URL': '',
'GPT_API_KEY': '',
'GPT_PROXY': '',
'GPT_MODEL': 'gpt-4o-mini',
'DEEPSEEK_BASE_URL': '',
'DEEPSEEK_API_KEY': '',
'DEEPSEEK_PROXY': '',
'DEEPSEEK_MODEL': 'deepseek-chat',
'VIRTUAL_APP_ENABLED': False, 'VIRTUAL_APP_ENABLED': False,
'FILE_UPLOAD_SIZE_LIMIT_MB': 200, 'FILE_UPLOAD_SIZE_LIMIT_MB': 200,

View File

@@ -241,15 +241,7 @@ ASSET_SIZE = 'small'
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
CHAT_AI_METHOD = CONFIG.CHAT_AI_METHOD CHAT_AI_METHOD = CONFIG.CHAT_AI_METHOD
CHAT_AI_EMBED_URL = CONFIG.CHAT_AI_EMBED_URL CHAT_AI_EMBED_URL = CONFIG.CHAT_AI_EMBED_URL
CHAT_AI_TYPE = CONFIG.CHAT_AI_TYPE CHAT_AI_DEFAULT_PROVIDER = CONFIG.CHAT_AI_DEFAULT_PROVIDER
GPT_BASE_URL = CONFIG.GPT_BASE_URL
GPT_API_KEY = CONFIG.GPT_API_KEY
GPT_PROXY = CONFIG.GPT_PROXY
GPT_MODEL = CONFIG.GPT_MODEL
DEEPSEEK_BASE_URL = CONFIG.DEEPSEEK_BASE_URL
DEEPSEEK_API_KEY = CONFIG.DEEPSEEK_API_KEY
DEEPSEEK_PROXY = CONFIG.DEEPSEEK_PROXY
DEEPSEEK_MODEL = CONFIG.DEEPSEEK_MODEL
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
@@ -269,3 +261,5 @@ TOOL_USER_ENABLED = CONFIG.TOOL_USER_ENABLED
SUGGESTION_LIMIT = CONFIG.SUGGESTION_LIMIT SUGGESTION_LIMIT = CONFIG.SUGGESTION_LIMIT
MCP_ENABLED = CONFIG.MCP_ENABLED MCP_ENABLED = CONFIG.MCP_ENABLED
CHAT_AI_PROVIDERS = CONFIG.CHAT_AI_PROVIDERS

View File

@@ -1,98 +1,10 @@
import httpx
import openai
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from rest_framework import status
from rest_framework.generics import GenericAPIView
from rest_framework.response import Response
from common.api import JMSModelViewSet from common.api import JMSModelViewSet
from common.permissions import IsValidUser, OnlySuperUser from common.permissions import IsValidUser, OnlySuperUser
from .. import serializers from .. import serializers
from ..const import ChatAITypeChoices
from ..models import ChatPrompt from ..models import ChatPrompt
from ..prompt import DefaultChatPrompt from ..prompt import DefaultChatPrompt
class ChatAITestingAPI(GenericAPIView):
serializer_class = serializers.ChatAISettingSerializer
rbac_perms = {
'POST': 'settings.change_chatai'
}
def get_config(self, request):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
data = self.serializer_class().data
data.update(serializer.validated_data)
for k, v in data.items():
if v:
continue
# 页面没有传递值, 从 settings 中获取
data[k] = getattr(settings, k, None)
return data
def post(self, request):
config = self.get_config(request)
chat_ai_enabled = config['CHAT_AI_ENABLED']
if not chat_ai_enabled:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data={'msg': _('Chat AI is not enabled')}
)
tp = config['CHAT_AI_TYPE']
if tp == ChatAITypeChoices.gpt:
url = config['GPT_BASE_URL']
api_key = config['GPT_API_KEY']
proxy = config['GPT_PROXY']
model = config['GPT_MODEL']
else:
url = config['DEEPSEEK_BASE_URL']
api_key = config['DEEPSEEK_API_KEY']
proxy = config['DEEPSEEK_PROXY']
model = config['DEEPSEEK_MODEL']
kwargs = {
'base_url': url or None,
'api_key': api_key,
}
try:
if proxy:
kwargs['http_client'] = httpx.Client(
proxies=proxy,
transport=httpx.HTTPTransport(local_address='0.0.0.0')
)
client = openai.OpenAI(**kwargs)
ok = False
error = ''
client.chat.completions.create(
messages=[
{
"role": "user",
"content": "Say this is a test",
}
],
model=model,
)
ok = True
except openai.APIConnectionError as e:
error = str(e.__cause__) # an underlying Exception, likely raised within httpx.
except openai.APIStatusError as e:
error = str(e.message)
except Exception as e:
ok, error = False, str(e)
if ok:
_status, msg = status.HTTP_200_OK, _('Test success')
else:
_status, msg = status.HTTP_400_BAD_REQUEST, error
return Response(status=_status, data={'msg': msg})
class ChatPromptViewSet(JMSModelViewSet): class ChatPromptViewSet(JMSModelViewSet):
serializer_classes = { serializer_classes = {
'default': serializers.ChatPromptSerializer, 'default': serializers.ChatPromptSerializer,

View File

@@ -154,7 +154,10 @@ class SettingsApi(generics.RetrieveUpdateAPIView):
def parse_serializer_data(self, serializer): def parse_serializer_data(self, serializer):
data = [] data = []
fields = self.get_fields() fields = self.get_fields()
encrypted_items = [name for name, field in fields.items() if field.write_only] encrypted_items = [
name for name, field in fields.items()
if field.write_only or getattr(field, 'encrypted', False)
]
category = self.request.query_params.get('category', '') category = self.request.query_params.get('category', '')
for name, value in serializer.validated_data.items(): for name, value in serializer.validated_data.items():
encrypted = name in encrypted_items encrypted = name in encrypted_items

View File

@@ -14,18 +14,5 @@ class ChatAIMethodChoices(TextChoices):
class ChatAITypeChoices(TextChoices): class ChatAITypeChoices(TextChoices):
gpt = 'gpt', 'GPT' openai = 'openai', 'Openai'
deep_seek = 'deep-seek', 'DeepSeek' ollama = 'ollama', 'Ollama'
class GPTModelChoices(TextChoices):
gpt_4o_mini = 'gpt-4o-mini', 'gpt-4o-mini'
gpt_4o = 'gpt-4o', 'gpt-4o'
o3_mini = 'o3-mini', 'o3-mini'
o1_mini = 'o1-mini', 'o1-mini'
o1 = 'o1', 'o1'
class DeepSeekModelChoices(TextChoices):
deepseek_chat = 'deepseek-chat', 'DeepSeek-V3'
deepseek_reasoner = 'deepseek-reasoner', 'DeepSeek-R1'

View File

@@ -1,6 +1,7 @@
import json import json
import os import os
import shutil import shutil
from typing import Any, Dict, List
from django.conf import settings from django.conf import settings
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
@@ -14,7 +15,6 @@ from rest_framework.utils.encoders import JSONEncoder
from common.db.models import JMSBaseModel from common.db.models import JMSBaseModel
from common.db.utils import Encryptor from common.db.utils import Encryptor
from common.utils import get_logger from common.utils import get_logger
from .const import ChatAITypeChoices
from .signals import setting_changed from .signals import setting_changed
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -196,20 +196,25 @@ class ChatPrompt(JMSBaseModel):
return self.name return self.name
def get_chatai_data(): def get_chatai_data() -> Dict[str, Any]:
data = { raw_providers = settings.CHAT_AI_PROVIDERS
'url': settings.GPT_BASE_URL, providers: List[dict] = [p for p in raw_providers if isinstance(p, dict)]
'api_key': settings.GPT_API_KEY,
'proxy': settings.GPT_PROXY,
'model': settings.GPT_MODEL,
}
if settings.CHAT_AI_TYPE != ChatAITypeChoices.gpt:
data['url'] = settings.DEEPSEEK_BASE_URL
data['api_key'] = settings.DEEPSEEK_API_KEY
data['proxy'] = settings.DEEPSEEK_PROXY
data['model'] = settings.DEEPSEEK_MODEL
return data if not providers:
return {}
selected = next(
(p for p in providers if p.get('is_assistant')),
providers[0],
)
return {
'url': selected.get('base_url'),
'api_key': selected.get('api_key'),
'proxy': selected.get('proxy'),
'model': selected.get('model'),
'name': selected.get('name'),
}
def init_sqlite_db(): def init_sqlite_db():

View File

@@ -10,11 +10,12 @@ from common.utils import date_expired_default
__all__ = [ __all__ = [
'AnnouncementSettingSerializer', 'OpsSettingSerializer', 'VaultSettingSerializer', 'AnnouncementSettingSerializer', 'OpsSettingSerializer', 'VaultSettingSerializer',
'HashicorpKVSerializer', 'AzureKVSerializer', 'TicketSettingSerializer', 'HashicorpKVSerializer', 'AzureKVSerializer', 'TicketSettingSerializer',
'ChatAISettingSerializer', 'VirtualAppSerializer', 'AmazonSMSerializer', 'ChatAIProviderSerializer', 'ChatAISettingSerializer',
'VirtualAppSerializer', 'AmazonSMSerializer',
] ]
from settings.const import ( from settings.const import (
ChatAITypeChoices, GPTModelChoices, DeepSeekModelChoices, ChatAIMethodChoices ChatAITypeChoices, ChatAIMethodChoices
) )
@@ -120,6 +121,29 @@ class AmazonSMSerializer(serializers.Serializer):
) )
class ChatAIProviderListSerializer(serializers.ListSerializer):
# 标记整个列表需要加密存储,避免明文保存 API Key
encrypted = True
class ChatAIProviderSerializer(serializers.Serializer):
type = serializers.ChoiceField(
default=ChatAITypeChoices.openai, choices=ChatAITypeChoices.choices,
label=_("Types"), required=False,
)
base_url = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
api_key = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
proxy = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
class ChatAISettingSerializer(serializers.Serializer): class ChatAISettingSerializer(serializers.Serializer):
PREFIX_TITLE = _('Chat AI') PREFIX_TITLE = _('Chat AI')
@@ -130,44 +154,14 @@ class ChatAISettingSerializer(serializers.Serializer):
default=ChatAIMethodChoices.api, choices=ChatAIMethodChoices.choices, default=ChatAIMethodChoices.api, choices=ChatAIMethodChoices.choices,
label=_("Method"), required=False, label=_("Method"), required=False,
) )
CHAT_AI_PROVIDERS = ChatAIProviderListSerializer(
child=ChatAIProviderSerializer(),
allow_empty=True, required=False, default=list, label=_('Providers')
)
CHAT_AI_EMBED_URL = serializers.CharField( CHAT_AI_EMBED_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'), allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.') help_text=_('The base URL of the Chat service.')
) )
CHAT_AI_TYPE = serializers.ChoiceField(
default=ChatAITypeChoices.gpt, choices=ChatAITypeChoices.choices,
label=_("Types"), required=False,
)
GPT_BASE_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
GPT_API_KEY = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
GPT_PROXY = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
GPT_MODEL = serializers.ChoiceField(
default=GPTModelChoices.gpt_4o_mini, choices=GPTModelChoices.choices,
label=_("GPT Model"), required=False,
)
DEEPSEEK_BASE_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
DEEPSEEK_API_KEY = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
DEEPSEEK_PROXY = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
DEEPSEEK_MODEL = serializers.ChoiceField(
default=DeepSeekModelChoices.deepseek_chat, choices=DeepSeekModelChoices.choices,
label=_("DeepSeek Model"), required=False,
)
class TicketSettingSerializer(serializers.Serializer): class TicketSettingSerializer(serializers.Serializer):

View File

@@ -73,8 +73,6 @@ class PrivateSettingSerializer(PublicSettingSerializer):
CHAT_AI_ENABLED = serializers.BooleanField() CHAT_AI_ENABLED = serializers.BooleanField()
CHAT_AI_METHOD = serializers.CharField() CHAT_AI_METHOD = serializers.CharField()
CHAT_AI_EMBED_URL = serializers.CharField() CHAT_AI_EMBED_URL = serializers.CharField()
CHAT_AI_TYPE = serializers.CharField()
GPT_MODEL = serializers.CharField()
FILE_UPLOAD_SIZE_LIMIT_MB = serializers.IntegerField() FILE_UPLOAD_SIZE_LIMIT_MB = serializers.IntegerField()
FTP_FILE_MAX_STORE = serializers.IntegerField() FTP_FILE_MAX_STORE = serializers.IntegerField()
LOKI_LOG_ENABLED = serializers.BooleanField() LOKI_LOG_ENABLED = serializers.BooleanField()

View File

@@ -21,7 +21,6 @@ urlpatterns = [
path('sms/<str:backend>/testing/', api.SMSTestingAPI.as_view(), name='sms-testing'), path('sms/<str:backend>/testing/', api.SMSTestingAPI.as_view(), name='sms-testing'),
path('sms/backend/', api.SMSBackendAPI.as_view(), name='sms-backend'), path('sms/backend/', api.SMSBackendAPI.as_view(), name='sms-backend'),
path('vault/<str:backend>/testing/', api.VaultTestingAPI.as_view(), name='vault-testing'), path('vault/<str:backend>/testing/', api.VaultTestingAPI.as_view(), name='vault-testing'),
path('chatai/testing/', api.ChatAITestingAPI.as_view(), name='chatai-testing'),
path('vault/sync/', api.VaultSyncDataAPI.as_view(), name='vault-sync'), path('vault/sync/', api.VaultSyncDataAPI.as_view(), name='vault-sync'),
path('security/block-ip/', api.BlockIPSecurityAPI.as_view(), name='block-ip'), path('security/block-ip/', api.BlockIPSecurityAPI.as_view(), name='block-ip'),
path('security/unlock-ip/', api.UnlockIPSecurityAPI.as_view(), name='unlock-ip'), path('security/unlock-ip/', api.UnlockIPSecurityAPI.as_view(), name='unlock-ip'),

View File

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from .applet import * from .applet import *
from .chat import *
from .component import * from .component import *
from .session import * from .session import *
from .virtualapp import * from .virtualapp import *

View File

@@ -0,0 +1 @@
from .chat import *

View File

@@ -0,0 +1,15 @@
from common.api import JMSBulkModelViewSet
from terminal import serializers
from terminal.filters import ChatFilter
from terminal.models import Chat
__all__ = ['ChatViewSet']
class ChatViewSet(JMSBulkModelViewSet):
queryset = Chat.objects.all()
serializer_class = serializers.ChatSerializer
filterset_class = ChatFilter
search_fields = ['title']
ordering_fields = ['date_updated']
ordering = ['-date_updated']

View File

@@ -2,7 +2,7 @@ from django.db.models import QuerySet
from django_filters import rest_framework as filters from django_filters import rest_framework as filters
from orgs.utils import filter_org_queryset from orgs.utils import filter_org_queryset
from terminal.models import Command, CommandStorage, Session from terminal.models import Command, CommandStorage, Session, Chat
class CommandFilter(filters.FilterSet): class CommandFilter(filters.FilterSet):
@@ -79,7 +79,34 @@ class CommandStorageFilter(filters.FilterSet):
model = CommandStorage model = CommandStorage
fields = ['real', 'name', 'type', 'is_default'] fields = ['real', 'name', 'type', 'is_default']
def filter_real(self, queryset, name, value): @staticmethod
def filter_real(queryset, name, value):
if value: if value:
queryset = queryset.exclude(name='null') queryset = queryset.exclude(name='null')
return queryset return queryset
class ChatFilter(filters.FilterSet):
ids = filters.BooleanFilter(method='filter_ids')
folder_ids = filters.BooleanFilter(method='filter_folder_ids')
class Meta:
model = Chat
fields = [
'title', 'user_id', 'pinned', 'folder_id',
'archived', 'socket_id', 'share_id'
]
@staticmethod
def filter_ids(queryset, name, value):
ids = value.split(',')
queryset = queryset.filter(id__in=ids)
return queryset
@staticmethod
def filter_folder_ids(queryset, name, value):
ids = value.split(',')
queryset = queryset.filter(folder_id__in=ids)
return queryset

View File

@@ -0,0 +1,38 @@
# Generated by Django 4.1.13 on 2025-09-30 06:57
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
('terminal', '0010_alter_command_risk_level_alter_session_login_from_and_more'),
]
operations = [
migrations.CreateModel(
name='Chat',
fields=[
('created_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('title', models.CharField(max_length=256, verbose_name='Title')),
('chat', models.JSONField(default=dict, verbose_name='Chat')),
('meta', models.JSONField(default=dict, verbose_name='Meta')),
('pinned', models.BooleanField(default=False, verbose_name='Pinned')),
('archived', models.BooleanField(default=False, verbose_name='Archived')),
('share_id', models.CharField(blank=True, default='', max_length=36)),
('folder_id', models.CharField(blank=True, default='', max_length=36)),
('socket_id', models.CharField(blank=True, default='', max_length=36)),
('user_id', models.CharField(blank=True, db_index=True, default='', max_length=36)),
('session_info', models.JSONField(default=dict, verbose_name='Session Info')),
],
options={
'verbose_name': 'Chat',
},
),
]

View File

@@ -1,4 +1,5 @@
from .session import *
from .component import *
from .applet import * from .applet import *
from .chat import *
from .component import *
from .session import *
from .virtualapp import * from .virtualapp import *

View File

@@ -0,0 +1 @@
from .chat import *

View File

@@ -0,0 +1,30 @@
from django.db import models
from django.utils.translation import gettext_lazy as _
from common.db.models import JMSBaseModel
from common.utils import get_logger
logger = get_logger(__name__)
__all__ = ['Chat']
class Chat(JMSBaseModel):
# id == session_id # 36 chars
title = models.CharField(max_length=256, verbose_name=_('Title'))
chat = models.JSONField(default=dict, verbose_name=_('Chat'))
meta = models.JSONField(default=dict, verbose_name=_('Meta'))
pinned = models.BooleanField(default=False, verbose_name=_('Pinned'))
archived = models.BooleanField(default=False, verbose_name=_('Archived'))
share_id = models.CharField(blank=True, default='', max_length=36)
folder_id = models.CharField(blank=True, default='', max_length=36)
socket_id = models.CharField(blank=True, default='', max_length=36)
user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
session_info = models.JSONField(default=dict, verbose_name=_('Session Info'))
class Meta:
verbose_name = _('Chat')
def __str__(self):
return self.title

View File

@@ -123,11 +123,10 @@ class Terminal(StorageMixin, TerminalStatusMixin, JMSBaseModel):
def get_chat_ai_setting(): def get_chat_ai_setting():
data = get_chatai_data() data = get_chatai_data()
return { return {
'GPT_BASE_URL': data['url'], 'GPT_BASE_URL': data.get('url'),
'GPT_API_KEY': data['api_key'], 'GPT_API_KEY': data.get('api_key'),
'GPT_PROXY': data['proxy'], 'GPT_PROXY': data.get('proxy'),
'GPT_MODEL': data['model'], 'CHAT_AI_PROVIDERS': settings.CHAT_AI_PROVIDERS,
'CHAT_AI_TYPE': settings.CHAT_AI_TYPE,
} }
@staticmethod @staticmethod

View File

@@ -2,8 +2,10 @@
# #
from .applet import * from .applet import *
from .applet_host import * from .applet_host import *
from .chat import *
from .command import * from .command import *
from .endpoint import * from .endpoint import *
from .loki import *
from .session import * from .session import *
from .sharing import * from .sharing import *
from .storage import * from .storage import *
@@ -11,4 +13,3 @@ from .task import *
from .terminal import * from .terminal import *
from .virtualapp import * from .virtualapp import *
from .virtualapp_provider import * from .virtualapp_provider import *
from .loki import *

View File

@@ -0,0 +1,28 @@
from rest_framework import serializers
from common.serializers import CommonBulkModelSerializer
from terminal.models import Chat
__all__ = ['ChatSerializer']
class ChatSerializer(CommonBulkModelSerializer):
created_at = serializers.SerializerMethodField()
updated_at = serializers.SerializerMethodField()
class Meta:
model = Chat
fields_mini = ['id', 'title', 'created_at', 'updated_at']
fields = fields_mini + [
'chat', 'meta', 'pinned', 'archived',
'share_id', 'folder_id',
'user_id', 'session_info'
]
@staticmethod
def get_created_at(obj):
return int(obj.date_created.timestamp())
@staticmethod
def get_updated_at(obj):
return int(obj.date_updated.timestamp())

View File

@@ -32,6 +32,7 @@ router.register(r'virtual-apps', api.VirtualAppViewSet, 'virtual-app')
router.register(r'app-providers', api.AppProviderViewSet, 'app-provider') router.register(r'app-providers', api.AppProviderViewSet, 'app-provider')
router.register(r'app-providers/((?P<provider>[^/.]+)/)?apps', api.AppProviderAppViewSet, 'app-provider-app') router.register(r'app-providers/((?P<provider>[^/.]+)/)?apps', api.AppProviderAppViewSet, 'app-provider-app')
router.register(r'virtual-app-publications', api.VirtualAppPublicationViewSet, 'virtual-app-publication') router.register(r'virtual-app-publications', api.VirtualAppPublicationViewSet, 'virtual-app-publication')
router.register(r'chats', api.ChatViewSet, 'chat')
urlpatterns = [ urlpatterns = [
path('my-sessions/', api.MySessionAPIView.as_view(), name='my-session'), path('my-sessions/', api.MySessionAPIView.as_view(), name='my-session'),

View File

@@ -199,11 +199,19 @@ class UserChangePasswordApi(UserQuerysetMixin, generics.UpdateAPIView):
class UserUnblockPKApi(UserQuerysetMixin, generics.UpdateAPIView): class UserUnblockPKApi(UserQuerysetMixin, generics.UpdateAPIView):
serializer_class = serializers.UserSerializer serializer_class = serializers.UserSerializer
def get_object(self):
pk = self.kwargs.get('pk')
if is_uuid(pk):
return super().get_object()
else:
return self.get_queryset().filter(username=pk).first()
def perform_update(self, serializer): def perform_update(self, serializer):
user = self.get_object() user = self.get_object()
username = user.username if user else '' if not user:
LoginBlockUtil.unblock_user(username) return Response({"error": _("User not found")}, status=404)
MFABlockUtils.unblock_user(username)
user.unblock_login()
class UserResetMFAApi(UserQuerysetMixin, generics.RetrieveAPIView): class UserResetMFAApi(UserQuerysetMixin, generics.RetrieveAPIView):

View File

@@ -274,8 +274,8 @@ class User(
LoginBlockUtil.unblock_user(self.username) LoginBlockUtil.unblock_user(self.username)
MFABlockUtils.unblock_user(self.username) MFABlockUtils.unblock_user(self.username)
@lazyproperty @property
def login_blocked(self): def is_login_blocked(self):
from users.utils import LoginBlockUtil, MFABlockUtils from users.utils import LoginBlockUtil, MFABlockUtils
if LoginBlockUtil.is_user_block(self.username): if LoginBlockUtil.is_user_block(self.username):
@@ -284,6 +284,13 @@ class User(
return True return True
return False return False
@classmethod
def block_login(cls, username):
from users.utils import LoginBlockUtil, MFABlockUtils
LoginBlockUtil.block_user(username)
MFABlockUtils.block_user(username)
def delete(self, using=None, keep_parents=False): def delete(self, using=None, keep_parents=False):
if self.pk == 1 or self.username == "admin": if self.pk == 1 or self.username == "admin":
raise PermissionDenied(_("Can not delete admin user")) raise PermissionDenied(_("Can not delete admin user"))

View File

@@ -123,7 +123,7 @@ class UserSerializer(
mfa_force_enabled = serializers.BooleanField( mfa_force_enabled = serializers.BooleanField(
read_only=True, label=_("MFA force enabled") read_only=True, label=_("MFA force enabled")
) )
login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked")) is_login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked"))
is_expired = serializers.BooleanField(read_only=True, label=_("Is expired")) is_expired = serializers.BooleanField(read_only=True, label=_("Is expired"))
is_valid = serializers.BooleanField(read_only=True, label=_("Is valid")) is_valid = serializers.BooleanField(read_only=True, label=_("Is valid"))
is_otp_secret_key_bound = serializers.BooleanField( is_otp_secret_key_bound = serializers.BooleanField(
@@ -193,6 +193,7 @@ class UserSerializer(
"is_valid", "is_expired", "is_active", # 布尔字段 "is_valid", "is_expired", "is_active", # 布尔字段
"is_otp_secret_key_bound", "can_public_key_auth", "is_otp_secret_key_bound", "can_public_key_auth",
"mfa_enabled", "need_update_password", "is_face_code_set", "mfa_enabled", "need_update_password", "is_face_code_set",
"is_login_blocked",
] ]
# 包含不太常用的字段,可以没有 # 包含不太常用的字段,可以没有
fields_verbose = ( fields_verbose = (
@@ -211,7 +212,7 @@ class UserSerializer(
# 多对多字段 # 多对多字段
fields_m2m = ["groups", "system_roles", "org_roles", "orgs_roles", "labels"] fields_m2m = ["groups", "system_roles", "org_roles", "orgs_roles", "labels"]
# 在serializer 上定义的字段 # 在serializer 上定义的字段
fields_custom = ["login_blocked", "password_strategy"] fields_custom = ["is_login_blocked", "password_strategy"]
fields = fields_verbose + fields_fk + fields_m2m + fields_custom fields = fields_verbose + fields_fk + fields_m2m + fields_custom
fields_unexport = ["avatar_url", "is_service_account"] fields_unexport = ["avatar_url", "is_service_account"]

View File

@@ -28,6 +28,6 @@ urlpatterns = [
path('users/<uuid:pk>/password/', api.UserChangePasswordApi.as_view(), name='change-user-password'), path('users/<uuid:pk>/password/', api.UserChangePasswordApi.as_view(), name='change-user-password'),
path('users/<uuid:pk>/password/reset/', api.UserResetPasswordApi.as_view(), name='user-reset-password'), path('users/<uuid:pk>/password/reset/', api.UserResetPasswordApi.as_view(), name='user-reset-password'),
path('users/<uuid:pk>/pubkey/reset/', api.UserResetPKApi.as_view(), name='user-public-key-reset'), path('users/<uuid:pk>/pubkey/reset/', api.UserResetPKApi.as_view(), name='user-public-key-reset'),
path('users/<uuid:pk>/unblock/', api.UserUnblockPKApi.as_view(), name='user-unblock'), path('users/<str:pk>/unblock/', api.UserUnblockPKApi.as_view(), name='user-unblock'),
] ]
urlpatterns += router.urls urlpatterns += router.urls

View File

@@ -186,6 +186,13 @@ class BlockUtilBase:
def is_block(self): def is_block(self):
return bool(cache.get(self.block_key)) return bool(cache.get(self.block_key))
@classmethod
def block_user(cls, username):
username = username.lower()
block_key = cls.BLOCK_KEY_TMPL.format(username)
key_ttl = int(settings.SECURITY_LOGIN_LIMIT_TIME) * 60
cache.set(block_key, True, key_ttl)
@classmethod @classmethod
def get_blocked_usernames(cls): def get_blocked_usernames(cls):
key = cls.BLOCK_KEY_TMPL.format('*') key = cls.BLOCK_KEY_TMPL.format('*')