diff --git a/frontend/src/components/common/account.js b/frontend/src/components/common/account.js index a8153f6f6f..ce030699f2 100644 --- a/frontend/src/components/common/account.js +++ b/frontend/src/components/common/account.js @@ -2,7 +2,7 @@ import React, { Component } from 'react'; import PropTypes from 'prop-types'; import { Utils } from '../../utils/utils'; import { seafileAPI } from '../../utils/seafile-api'; -import { siteRoot, isPro, gettext, appAvatarURL, enableSSOToThirdpartWebsite } from '../../utils/constants'; +import { siteRoot, isPro, gettext, appAvatarURL, enableSSOToThirdpartWebsite, enableSeafileAI } from '../../utils/constants'; import toaster from '../toast'; const { @@ -22,6 +22,9 @@ class Account extends Component { contactEmail: '', quotaUsage: '', quotaTotal: '', + aiCredit: '', + aiCost: '', + aiUsageRate: '', isStaff: false, isOrgStaff: false, usageRate: '', @@ -80,6 +83,9 @@ class Account extends Component { isOrgStaff: resp.data.is_org_staff === 1 ? true : false, showInfo: !this.state.showInfo, enableSubscription: resp.data.enable_subscription, + aiCredit: resp.data.ai_credit, + aiCost: resp.data.ai_cost, + aiUsageRate: resp.data.ai_usage_rate }); }).catch(error => { let errMessage = Utils.getErrorMsg(error); @@ -157,11 +163,20 @@ class Account extends Component {
{this.state.userName}
-
+

{gettext('Used:')}{' '}{this.state.quotaUsage} / {this.state.quotaTotal}

+ {enableSeafileAI && +
+
+

{gettext('AI credit used:')}{' '}{this.state.aiCost} / {this.state.aiCredit > 0 ? this.state.aiCredit : '--'}

+
+
+
+ } + {gettext('Settings')} {(this.state.enableSubscription && !isOrgContext) && {'付费管理'}} {this.renderMenu()} diff --git a/frontend/src/metadata/components/dialog/file-tags-dialog/index.js b/frontend/src/metadata/components/dialog/file-tags-dialog/index.js index b6de5ce5e7..c74a51578c 100644 --- a/frontend/src/metadata/components/dialog/file-tags-dialog/index.js +++ b/frontend/src/metadata/components/dialog/file-tags-dialog/index.js @@ -52,7 +52,11 @@ const FileTagsDialog = ({ record, onToggle, onSubmit }) => { setExitTags(exitTags); setLoading(false); }).catch(error => { - const errorMessage = gettext('Failed to generate file tags'); + let errorMessage = gettext('Failed to generate file tags'); + if (error.status === 429) { + const err_data = error.response.data; + errorMessage = gettext(err_data.error_msg); + } toaster.danger(errorMessage); setLoading(false); }); diff --git a/frontend/src/metadata/components/popover/ocr-result-popover/index.js b/frontend/src/metadata/components/popover/ocr-result-popover/index.js index eb21a2468c..d4f58f0fbb 100644 --- a/frontend/src/metadata/components/popover/ocr-result-popover/index.js +++ b/frontend/src/metadata/components/popover/ocr-result-popover/index.js @@ -130,7 +130,11 @@ const OCRResultPopover = ({ repoID, target, record, onToggle, saveToDescription setValue(value); setLoading(false); }).catch(error => { - const errorMessage = gettext('Failed to extract text'); + let errorMessage = gettext('Failed to extract text'); + if (error.status === 429) { + const err_data = error.response.data; + errorMessage = gettext(err_data.error_msg); + } setErrorMessage(errorMessage); setLoading(false); }); diff --git a/frontend/src/metadata/hooks/metadata-ai-operation.js b/frontend/src/metadata/hooks/metadata-ai-operation.js index 1d44b93633..0684cfe588 100644 --- a/frontend/src/metadata/hooks/metadata-ai-operation.js +++ b/frontend/src/metadata/hooks/metadata-ai-operation.js @@ -83,7 +83,11 @@ export const MetadataAIOperationsProvider = ({ success_callback && success_callback({ parentDir, fileName, description }); }).catch(error => { inProgressToaster.close(); - const errorMessage = gettext('Failed to generate description'); + let errorMessage = gettext('Failed to generate description'); + if (error.status === 429) { + const err_data = error.response.data; + errorMessage = gettext(err_data.error_msg); + } toaster.danger(errorMessage); fail_callback && fail_callback(); }); diff --git a/frontend/src/pages/org-admin/org-info.js b/frontend/src/pages/org-admin/org-info.js index a4bb1a193f..35330440bc 100644 --- a/frontend/src/pages/org-admin/org-info.js +++ b/frontend/src/pages/org-admin/org-info.js @@ -1,6 +1,6 @@ import React, { Component, Fragment } from 'react'; import { orgAdminAPI } from '../../utils/org-admin-api'; -import { mediaUrl, gettext, orgMemberQuotaEnabled } from '../../utils/constants'; +import { mediaUrl, gettext, orgMemberQuotaEnabled, enableSeafileAI } from '../../utils/constants'; import { Utils } from '../../utils/utils'; import MainPanelTopbar from './main-panel-topbar'; import '../../css/org-admin-info-page.css'; @@ -27,12 +27,12 @@ class OrgInfo extends Component { const { org_id, org_name, traffic_this_month, traffic_limit, member_quota, member_usage, active_members, - storage_quota, storage_usage + storage_quota, storage_usage, ai_cost, ai_credit } = res.data; this.setState({ org_id, org_name, traffic_this_month, traffic_limit, member_quota, member_usage, active_members, - storage_quota, storage_usage + storage_quota, storage_usage, ai_cost, ai_credit }); }); } @@ -41,7 +41,7 @@ class OrgInfo extends Component { const { org_id, org_name, traffic_this_month, traffic_limit, member_quota, member_usage, active_members, - storage_quota, storage_usage + storage_quota, storage_usage, ai_cost, ai_credit } = this.state; let download_traffic = traffic_this_month.link_file_download + traffic_this_month.sync_file_download + traffic_this_month.web_file_download; download_traffic = download_traffic ? download_traffic : 0; @@ -124,7 +124,21 @@ class OrgInfo extends Component {

{Utils.bytesToSize(download_traffic)}

)}
+ {enableSeafileAI && ( +
+

{gettext('AI credit used this month')}

+ <> +

{`${ai_credit > 0 ? (ai_cost / ai_credit * 100).toFixed(2) : '0'}%`}

+
+
+
+
+

{`${ai_cost} / ${ai_credit > 0 ? ai_credit : '--'}`}

+
+ +
+ )} diff --git a/seahub/ai/apis.py b/seahub/ai/apis.py index 52f8f0fc48..f001da7201 100644 --- a/seahub/ai/apis.py +++ b/seahub/ai/apis.py @@ -16,7 +16,7 @@ from seahub.api2.authentication import TokenAuthentication, SdocJWTTokenAuthenti from seahub.utils import get_file_type_and_ext, IMAGE from seahub.views import check_folder_permission from seahub.ai.utils import image_caption, translate, writing_assistant, verify_ai_config, generate_summary, \ - generate_file_tags, ocr + generate_file_tags, ocr, is_ai_usage_over_limit logger = logging.getLogger(__name__) @@ -33,7 +33,8 @@ class ImageCaption(APIView): repo_id = request.data.get('repo_id') path = request.data.get('path') lang = request.data.get('lang') - + org_id = request.user.org.org_id if request.user.org else None + username = request.user.username if not repo_id: return api_error(status.HTTP_400_BAD_REQUEST, 'repo_id invalid') if not path: @@ -50,6 +51,9 @@ class ImageCaption(APIView): error_msg = 'Library %s not found.' % repo_id return api_error(status.HTTP_404_NOT_FOUND, error_msg) + if is_ai_usage_over_limit(request.user, org_id): + return api_error(status.HTTP_429_TOO_MANY_REQUESTS, 'Credit not enough') + permission = check_folder_permission(request, repo_id, os.path.dirname(path)) if not permission: error_msg = 'Permission denied.' @@ -72,7 +76,9 @@ class ImageCaption(APIView): params = { 'path': path, 'download_token': token, - 'lang': lang + 'lang': lang, + 'org_id': org_id, + 'username': username } try: @@ -96,6 +102,8 @@ class GenerateSummary(APIView): repo_id = request.data.get('repo_id') path = request.data.get('path') + org_id = request.user.org.org_id if request.user.org else None + username = request.user.username if not repo_id: return api_error(status.HTTP_400_BAD_REQUEST, 'repo_id invalid') @@ -112,6 +120,9 @@ class GenerateSummary(APIView): error_msg = 'Permission denied.' return api_error(status.HTTP_403_FORBIDDEN, error_msg) + if is_ai_usage_over_limit(request.user, org_id): + return api_error(status.HTTP_429_TOO_MANY_REQUESTS, 'Credit not enough') + try: file_id = seafile_api.get_file_id_by_path(repo_id, path) except SearpcError as e: @@ -128,7 +139,9 @@ class GenerateSummary(APIView): params = { 'path': path, - 'download_token': token + 'download_token': token, + 'org_id': org_id, + 'username': username } try: @@ -152,6 +165,8 @@ class GenerateFileTags(APIView): repo_id = request.data.get('repo_id') path = request.data.get('path') + org_id = request.user.org.org_id if request.user.org else None + username = request.user.username if not repo_id: return api_error(status.HTTP_400_BAD_REQUEST, 'repo_id invalid') @@ -168,6 +183,9 @@ class GenerateFileTags(APIView): error_msg = 'Permission denied.' return api_error(status.HTTP_403_FORBIDDEN, error_msg) + if is_ai_usage_over_limit(request.user, org_id): + return api_error(status.HTTP_429_TOO_MANY_REQUESTS, 'Credit not enough') + try: file_id = seafile_api.get_file_id_by_path(repo_id, path) except SearpcError as e: @@ -185,6 +203,8 @@ class GenerateFileTags(APIView): params = { 'path': path, 'download_token': token, + 'org_id': org_id, + 'username': username } file_type, _ = get_file_type_and_ext(os.path.basename(path)) @@ -230,7 +250,8 @@ class OCR(APIView): repo_id = request.data.get('repo_id') path = request.data.get('path') - + org_id = request.user.org.org_id if request.user.org else None + username = request.user.username if not repo_id: return api_error(status.HTTP_400_BAD_REQUEST, 'repo_id invalid') if not path: @@ -250,6 +271,9 @@ class OCR(APIView): error_msg = 'Permission denied.' return api_error(status.HTTP_403_FORBIDDEN, error_msg) + if is_ai_usage_over_limit(request.user, org_id): + return api_error(status.HTTP_429_TOO_MANY_REQUESTS, 'Credit not enough') + try: file_id = seafile_api.get_file_id_by_path(repo_id, path) except SearpcError as e: @@ -272,6 +296,8 @@ class OCR(APIView): params = { 'file_name': os.path.basename(path), 'download_token': token, + 'org_id': org_id, + 'username': username } try: @@ -296,15 +322,22 @@ class Translate(APIView): text = request.data.get('text') lang = request.data.get('lang') + org_id = request.user.org.org_id if request.user.org else None + username = request.user.username if not text: return api_error(status.HTTP_400_BAD_REQUEST, 'text invalid') if not lang: return api_error(status.HTTP_400_BAD_REQUEST, 'lang invalid') + + if is_ai_usage_over_limit(request.user, org_id): + return api_error(status.HTTP_429_TOO_MANY_REQUESTS, 'Credit not enough') params = { 'text': text, 'lang': lang, + 'org_id': org_id, + 'username': username } try: @@ -329,16 +362,23 @@ class WritingAssistant(APIView): text = request.data.get('text') writing_type = request.data.get('writing_type') custom_prompt = request.data.get('custom_prompt') + org_id = request.user.org.org_id if request.user.org else None + username = request.user.username if not text: return api_error(status.HTTP_400_BAD_REQUEST, 'text invalid') if not custom_prompt and not writing_type: return api_error(status.HTTP_400_BAD_REQUEST, 'writing_type invalid') + if is_ai_usage_over_limit(request.user, org_id): + return api_error(status.HTTP_429_TOO_MANY_REQUESTS, 'Credit not enough') + params = { 'text': text, 'writing_type': writing_type, 'custom_prompt': custom_prompt, + 'org_id': org_id, + 'username': username } try: diff --git a/seahub/ai/models.py b/seahub/ai/models.py new file mode 100644 index 0000000000..4886fbcbdd --- /dev/null +++ b/seahub/ai/models.py @@ -0,0 +1,31 @@ +from django.db import models + + +class StatsAIByTeam(models.Model): + org_id = models.BigIntegerField(null=False) + month = models.DateField(null=False, db_index=True) + model = models.CharField(max_length=100, null=False) + input_tokens = models.IntegerField() + output_tokens = models.IntegerField() + cost = models.FloatField() + created_at = models.DateTimeField() + updated_at = models.DateTimeField() + + class Meta: + db_table = 'stats_ai_by_team' + unique_together = (('org_id', 'month', 'model'),) + + +class StatsAIByOwner(models.Model): + username = models.CharField(max_length=255, null=False) + month = models.DateField(null=False, db_index=True) + model = models.CharField(max_length=100, null=False) + input_tokens = models.IntegerField() + output_tokens = models.IntegerField() + cost = models.FloatField() + created_at = models.DateTimeField() + updated_at = models.DateTimeField() + + class Meta: + db_table = 'stats_ai_by_owner' + unique_together = (('username', 'month', 'model'),) diff --git a/seahub/ai/utils.py b/seahub/ai/utils.py index db2a434649..0d486b0459 100644 --- a/seahub/ai/utils.py +++ b/seahub/ai/utils.py @@ -4,11 +4,26 @@ import jwt import time from urllib.parse import urljoin +from django.utils import timezone +from django.db.models.functions import Coalesce +from django.db.models import Sum, Value + from seahub.settings import SEAFILE_AI_SECRET_KEY, SEAFILE_AI_SERVER_URL +from seahub.role_permissions.utils import get_enabled_role_permissions_by_role +from seahub.constants import DEFAULT_USER +from seahub.utils.user_permissions import get_user_role +from seahub.utils.ccnet_db import CcnetDB +from seahub.organizations.models import OrgMemberQuota +from seahub.ai.models import StatsAIByOwner, StatsAIByTeam +try: + from seahub.settings import ORG_MEMBER_QUOTA_ENABLED +except ImportError: + ORG_MEMBER_QUOTA_ENABLED = False logger = logging.getLogger(__name__) +# API def gen_headers(): payload = {'exp': int(time.time()) + 300, } token = jwt.encode(payload, SEAFILE_AI_SECRET_KEY, algorithm='HS256') @@ -63,8 +78,40 @@ def writing_assistant(params): return resp -def extract_text(params): - headers = gen_headers() - url = urljoin(SEAFILE_AI_SERVER_URL, '/api/v1/extract-text/') - resp = requests.post(url, json=params, headers=headers, timeout=30) - return resp +# utils +def get_ai_credit_by_user(user, org_id): + user_role = get_user_role(user) + role = DEFAULT_USER if (user_role == '' or user_role == DEFAULT_USER) else user_role + ai_credit_per_user = get_enabled_role_permissions_by_role(role)['monthly_ai_credit_per_user'] + if ai_credit_per_user < 0: + return -1 + if org_id and org_id != -1: + if ORG_MEMBER_QUOTA_ENABLED: + org_members_quota = OrgMemberQuota.objects.get_quota(org_id) + ai_credit = org_members_quota * ai_credit_per_user + else: + ccnet_db = CcnetDB() + user_count = ccnet_db.get_org_user_count(org_id) + ai_credit = user_count * ai_credit_per_user + else: + ai_credit = ai_credit_per_user + return ai_credit + + +def get_ai_cost_by_user(user, org_id): + month = timezone.now().replace(day=1) + if org_id and org_id > 0: + cost = StatsAIByTeam.objects.filter(org_id=org_id, month=month).aggregate(total_cost=Coalesce(Sum('cost'), Value(0.0)))['total_cost'] + else: + cost = StatsAIByOwner.objects.filter(username=user.username, month=month).aggregate(total_cost=Coalesce(Sum('cost'), Value(0.0)))['total_cost'] + return cost + + +def is_ai_usage_over_limit(user, org_id): + ai_credit = get_ai_credit_by_user(user, org_id) + cost = get_ai_cost_by_user(user, org_id) + + if ai_credit < 0: + return False + + return ai_credit <= round(cost, 2) diff --git a/seahub/api2/authentication.py b/seahub/api2/authentication.py index e0148dda53..81acbe9640 100644 --- a/seahub/api2/authentication.py +++ b/seahub/api2/authentication.py @@ -214,6 +214,11 @@ class SdocJWTTokenAuthentication(BaseAuthentication): user = None if not user or not user.is_active: return None + + if MULTI_TENANCY: + orgs = ccnet_api.get_orgs_by_user(username) + if orgs: + user.org = orgs[0] return user, auth[1] diff --git a/seahub/api2/views.py b/seahub/api2/views.py index 19f4ab0e1f..e6f9227d9c 100644 --- a/seahub/api2/views.py +++ b/seahub/api2/views.py @@ -106,7 +106,8 @@ from seahub.settings import THUMBNAIL_EXTENSION, THUMBNAIL_ROOT, \ STORAGE_CLASS_MAPPING_POLICY, \ ENABLE_RESET_ENCRYPTED_REPO_PASSWORD, SHARE_LINK_EXPIRE_DAYS_MAX, \ SHARE_LINK_EXPIRE_DAYS_MIN, SHARE_LINK_EXPIRE_DAYS_DEFAULT, \ - ENABLE_METADATA_FOR_NEW_LIBRARY, ENABLE_METADATA_MANAGEMENT + ENABLE_METADATA_FOR_NEW_LIBRARY, ENABLE_METADATA_MANAGEMENT, \ + ENABLE_SEAFILE_AI, SEAFILE_AI_SERVER_URL from seahub.subscription.utils import subscription_check from seahub.organizations.models import OrgAdminSettings, DISABLE_ORG_ENCRYPTED_LIBRARY from seahub.seadoc.utils import get_seadoc_file_uuid, gen_seadoc_image_parent_path, get_seadoc_asset_upload_link @@ -114,6 +115,7 @@ from seahub.views.file import get_office_feature_by_repo from seahub.repo_metadata.models import RepoMetadata, RepoMetadataViews from seahub.repo_metadata.utils import init_metadata, init_tags, add_init_metadata_task from seahub.repo_metadata.metadata_server_api import MetadataServerAPI +from seahub.ai.utils import get_ai_credit_by_user, get_ai_cost_by_user try: from seahub.settings import CLOUD_MODE @@ -311,17 +313,25 @@ class AccountInfo(APIView): email = request.user.username p = Profile.objects.get_profile_by_user(email) d_p = DetailedProfile.objects.get_detailed_profile_by_user(email) - + org_id = None if is_org_context(request): org_id = request.user.org.org_id quota_total = seafile_api.get_org_user_quota(org_id, email) quota_usage = seafile_api.get_org_user_quota_usage(org_id, email) is_org_staff = request.user.org.is_staff info['is_org_staff'] = is_org_staff + else: quota_total = seafile_api.get_user_quota(email) quota_usage = seafile_api.get_user_self_usage(email) + if ENABLE_SEAFILE_AI and SEAFILE_AI_SERVER_URL: + info['ai_credit'] = get_ai_credit_by_user(request.user, org_id) + info['ai_cost'] = round(get_ai_cost_by_user(request.user, org_id), 2) + info['ai_usage_rate'] = str(float(info['ai_cost']) / info['ai_credit'] * 100) + '%' + if info['ai_credit'] == -1: + info['ai_usage_rate'] = '0%' + if quota_total > 0: info['space_usage'] = str(float(quota_usage) / quota_total * 100) + '%' else: # no space quota set in config diff --git a/seahub/organizations/api/admin/info.py b/seahub/organizations/api/admin/info.py index cf488d4f40..a676b6ad1a 100644 --- a/seahub/organizations/api/admin/info.py +++ b/seahub/organizations/api/admin/info.py @@ -28,6 +28,9 @@ from seahub.organizations.models import OrgAdminSettings, \ from seahub.organizations.settings import ORG_MEMBER_QUOTA_ENABLED, \ ORG_ENABLE_ADMIN_CUSTOM_NAME +from django.conf import settings as dj_settings +from seahub.ai.utils import get_ai_cost_by_user, get_ai_credit_by_user + logger = logging.getLogger(__name__) @@ -100,6 +103,10 @@ def get_org_info(request, org_id): traffic_limit = get_quota_from_string(monthly_rate_limit_per_user) * member_quota info['traffic_limit'] = traffic_limit + if dj_settings.ENABLE_SEAFILE_AI and dj_settings.SEAFILE_AI_SERVER_URL: + info['ai_cost'] = round(get_ai_cost_by_user(request.user, org_id), 2) + info['ai_credit'] = get_ai_credit_by_user(request.user, org_id) + info['storage_quota'] = storage_quota info['storage_usage'] = storage_usage info['user_default_quota'] = user_default_quota diff --git a/seahub/role_permissions/settings.py b/seahub/role_permissions/settings.py index 0977370b0f..d535548299 100644 --- a/seahub/role_permissions/settings.py +++ b/seahub/role_permissions/settings.py @@ -50,6 +50,7 @@ DEFAULT_ENABLED_ROLE_PERMISSIONS = { 'monthly_rate_limit': '', 'monthly_rate_limit_per_user': '', 'can_choose_office_suite': True, + 'monthly_ai_credit_per_user': -1, }, GUEST_USER: { 'can_add_repo': False, diff --git a/seahub/settings.py b/seahub/settings.py index d7a65147b4..10473056ba 100644 --- a/seahub/settings.py +++ b/seahub/settings.py @@ -1086,6 +1086,7 @@ ENABLE_SEAFILE_AI = False SEAFILE_AI_SERVER_URL = '' SEAFILE_AI_SECRET_KEY = '' +AI_PRICES = {} d = os.path.dirname EVENTS_CONFIG_FILE = os.environ.get( diff --git a/sql/mysql.sql b/sql/mysql.sql index 289efec3cd..bd46b03fec 100644 --- a/sql/mysql.sql +++ b/sql/mysql.sql @@ -1654,3 +1654,33 @@ CREATE TABLE `notifications_sysusernotification` ( KEY `notifications_sysusernotification_seen_9d851bf7` (`seen`), KEY `notifications_sysusernotification_created_at_56ffd2a0` (`created_at`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; + +CREATE TABLE `stats_ai_by_team` ( + `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, + `org_id` bigint(20) NOT NULL, + `month` date NOT NULL, + `model` varchar(100) NOT NULL, + `input_tokens` int(11) DEFAULT NULL, + `output_tokens` int(11) DEFAULT NULL, + `cost` double NOT NULL, + `created_at` datetime(6) DEFAULT NULL, + `updated_at` datetime(6) DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `stats_ai_by_team_org_id_month_model` (`org_id`,`month`,`model`), + KEY `ix_stats_ai_by_team_org_id_month` (`org_id`,`month`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; + +CREATE TABLE `stats_ai_by_owner` ( + `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, + `username` varchar(255) NOT NULL, + `month` date NOT NULL, + `model` varchar(100) NOT NULL, + `input_tokens` int(11) DEFAULT NULL, + `output_tokens` int(11) DEFAULT NULL, + `cost` double NOT NULL, + `created_at` datetime(6) DEFAULT NULL, + `updated_at` datetime(6) DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `stats_ai_by_owner_username_month_model` (`username`,`month`,`model`), + KEY `ix_stats_ai_by_owner_username_month` (`username`,`month`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; diff --git a/tests/seahub/role_permissions/test_utils.py b/tests/seahub/role_permissions/test_utils.py index 02b4a4a1d6..2c21539e99 100644 --- a/tests/seahub/role_permissions/test_utils.py +++ b/tests/seahub/role_permissions/test_utils.py @@ -11,4 +11,4 @@ class UtilsTest(BaseTestCase): assert DEFAULT_USER in get_available_roles() def test_get_enabled_role_permissions_by_role(self): - assert len(list(get_enabled_role_permissions_by_role(DEFAULT_USER).keys())) == 24 + assert len(list(get_enabled_role_permissions_by_role(DEFAULT_USER).keys())) == 25