diff --git a/requirements.txt b/requirements.txt index e875b4bfdf..900b91de74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ gunicorn==19.8.1 django-webpack-loader==0.6.0 git+git://github.com/haiwen/python-cas.git@ffc49235fd7cc32c4fdda5acfa3707e1405881df#egg=python_cas futures==3.2.0 +social-auth-core==1.7.0 diff --git a/seahub/notifications/management/commands/send_wxwork_notices.py b/seahub/notifications/management/commands/send_wxwork_notices.py new file mode 100644 index 0000000000..c6e6125086 --- /dev/null +++ b/seahub/notifications/management/commands/send_wxwork_notices.py @@ -0,0 +1,164 @@ +# Copyright (c) 2012-2016 Seafile Ltd. +# encoding: utf-8 +from datetime import datetime +import logging +import re + +from django.conf import settings +from django.core.management.base import BaseCommand +from django.core.urlresolvers import reverse +from django.utils import translation +from django.utils.translation import ungettext +from social_django.models import UserSocialAuth +from weworkapi import CorpApi + +from seahub.base.models import CommandsLastCheck +from seahub.notifications.models import UserNotification +from seahub.profile.models import Profile +from seahub.utils import get_site_scheme_and_netloc, get_site_name + +# Get an instance of a logger +logger = logging.getLogger(__name__) + +########## Utility Functions ########## +def wrap_div(s): + """ + Replace xx to xx and wrap content with
. + """ + patt = '(.+?)' + + def repl(matchobj): + return matchobj.group(1) + + return '
' + re.sub(patt, repl, s) + '
' + +class CommandLogMixin(object): + def println(self, msg): + self.stdout.write('[%s] %s\n' % (str(datetime.now()), msg)) + + def log_error(self, msg): + logger.error(msg) + self.println(msg) + + def log_info(self, msg): + logger.info(msg) + self.println(msg) + + def log_debug(self, msg): + logger.debug(msg) + self.println(msg) + +####################################### + +class Command(BaseCommand, CommandLogMixin): + help = 'Send WeChat Work msg to user if he/she has unseen notices every ' + 'period of time.' + label = "notifications_send_wxwork_notices" + + def handle(self, *args, **options): + self.log_debug('Start sending WeChat Work msg...') + self.api = CorpApi.CorpApi(settings.SOCIAL_AUTH_WEIXIN_WORK_KEY, + settings.SOCIAL_AUTH_WEIXIN_WORK_SECRET) + + self.do_action() + self.log_debug('Finish sending WeChat Work msg.\n') + + def send_wx_msg(self, uid, title, content, detail_url): + try: + self.log_info('Send wechat msg to user: %s, msg: %s' % (uid, content)) + response = self.api.httpCall( + CorpApi.CORP_API_TYPE['MESSAGE_SEND'], + { + "touser": uid, + "agentid": settings.SOCIAL_AUTH_WEIXIN_WORK_AGENTID, + 'msgtype': 'textcard', + # 'climsgid': 'climsgidclimsgid_d', + 'textcard': { + 'title': title, + 'description': content, + 'url': detail_url, + }, + 'safe': 0, + }) + self.log_info(response) + except Exception as ex: + logger.error(ex, exc_info=True) + + def get_user_language(self, username): + return Profile.objects.get_user_language(username) + + def do_action(self): + now = datetime.now() + today = datetime.now().replace(hour=0).replace(minute=0).replace( + second=0).replace(microsecond=0) + + # 1. get all users who are connected wechat work + socials = UserSocialAuth.objects.filter(provider='weixin-work') + users = [(x.username, x.uid) for x in socials] + if not users: + return + + user_uid_map = {} + for username, uid in users: + user_uid_map[username] = uid + + # 2. get previous time that command last runs + try: + cmd_last_check = CommandsLastCheck.objects.get(command_type=self.label) + self.log_debug('Last check time is %s' % cmd_last_check.last_check) + + last_check_dt = cmd_last_check.last_check + + cmd_last_check.last_check = now + cmd_last_check.save() + except CommandsLastCheck.DoesNotExist: + last_check_dt = today + self.log_debug('Create new last check time: %s' % now) + CommandsLastCheck(command_type=self.label, last_check=now).save() + + # 3. get all unseen notices for those users + qs = UserNotification.objects.filter( + timestamp__gt=last_check_dt + ).filter(seen=False).filter( + to_user__in=user_uid_map.keys() + ) + + user_notices = {} + for q in qs: + if q.to_user not in user_notices: + user_notices[q.to_user] = [q] + else: + user_notices[q.to_user].append(q) + + # 4. send msg to users + url = get_site_scheme_and_netloc().rstrip('/') + reverse('user_notification_list') + + for username, uid in users: + notices = user_notices.get(username, []) + count = len(notices) + if count == 0: + continue + + # save current language + cur_language = translation.get_language() + + # get and active user language + user_language = self.get_user_language(username) + translation.activate(user_language) + self.log_debug('Set language code to %s for user: %s' % ( + user_language, username)) + + title = ungettext( + "\n" + "You've got 1 new notice on %(site_name)s:\n", + "\n" + "You've got %(num)s new notices on %(site_name)s:\n", + count + ) % { + 'num': count, + 'site_name': get_site_name(), + } + content = ''.join([wrap_div(x.format_msg()) for x in notices]) + self.send_wx_msg(uid, title, content, url) + + translation.activate(cur_language) diff --git a/seahub/notifications/migrations/0003_auto_20181115_0825.py b/seahub/notifications/migrations/0003_auto_20181115_0825.py new file mode 100644 index 0000000000..8f7b676070 --- /dev/null +++ b/seahub/notifications/migrations/0003_auto_20181115_0825.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.15 on 2018-11-15 08:25 +from __future__ import unicode_literals + +import datetime +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('notifications', '0002_auto_20180426_0710'), + ] + + operations = [ + migrations.AlterField( + model_name='usernotification', + name='timestamp', + field=models.DateTimeField(db_index=True, default=datetime.datetime.now), + ), + ] diff --git a/seahub/notifications/models.py b/seahub/notifications/models.py index c6497c1958..e982af2515 100644 --- a/seahub/notifications/models.py +++ b/seahub/notifications/models.py @@ -23,7 +23,7 @@ from seahub.invitations.models import Invitation from seahub.utils.repo import get_repo_shared_users from seahub.utils import normalize_cache_key from seahub.utils.timeutils import datetime_to_isoformat_timestr -from seahub.constants import HASH_URLS +from seahub.constants import HASH_URLS # Get an instance of a logger logger = logging.getLogger(__name__) @@ -333,7 +333,7 @@ class UserNotification(models.Model): to_user = LowerCaseCharField(db_index=True, max_length=255) msg_type = models.CharField(db_index=True, max_length=30) detail = models.TextField() - timestamp = models.DateTimeField(default=datetime.datetime.now) + timestamp = models.DateTimeField(db_index=True, default=datetime.datetime.now) seen = models.BooleanField('seen', default=False) objects = UserNotificationManager() @@ -487,6 +487,32 @@ class UserNotification(models.Model): return {'message': message, 'msg_from': msg_from} ########## functions used in templates + def format_msg(self): + if self.is_group_msg(): + return self.format_group_message_title() + elif self.is_file_uploaded_msg(): + return self.format_file_uploaded_msg() + elif self.is_repo_share_msg(): + return self.format_repo_share_msg() + elif self.is_repo_share_to_group_msg(): + return self.format_repo_share_to_group_msg() + elif self.is_group_join_request(): + return self.format_group_join_request() + elif self.is_file_comment_msg(): + return self.format_file_comment_msg() + elif self.is_review_comment_msg(): + return self.format_review_comment_msg() + elif self.is_update_review_msg(): + return self.format_update_review_msg() + elif self.is_request_reviewer_msg(): + return self.format_request_reviewer_msg() + elif self.is_guest_invitation_accepted_msg(): + return self.format_guest_invitation_accepted_msg() + elif self.is_add_user_to_group(): + return self.format_add_user_to_group() + else: + return '' + def format_file_uploaded_msg(self): """ diff --git a/seahub/profile/templates/profile/set_profile.html b/seahub/profile/templates/profile/set_profile.html index fa665bac9d..526cf49834 100644 --- a/seahub/profile/templates/profile/set_profile.html +++ b/seahub/profile/templates/profile/set_profile.html @@ -175,6 +175,29 @@ {% endif %} +
+

{% trans "Social Login" %}

+ + + +
+ {% if ENABLE_DELETE_ACCOUNT %}

{% trans "Delete Account" %}

@@ -379,5 +402,11 @@ $('#set-email-notice-interval-form').on('submit', function() { }); return false; }); + +addConfirmTo($('.social-disconnect'), { + 'title':"{% trans "Disconnect" %}", + 'con':"{% trans "Are you sure you want to disconnect?" %}", + 'post':true +}); {% endblock %} diff --git a/seahub/profile/views.py b/seahub/profile/views.py index cd1b8cced0..ea3b163e42 100644 --- a/seahub/profile/views.py +++ b/seahub/profile/views.py @@ -86,6 +86,10 @@ def edit_profile(request): email_inverval = UserOptions.objects.get_file_updates_email_interval(username) email_inverval = email_inverval if email_inverval is not None else 0 + from social_django.models import UserSocialAuth + social_connected = UserSocialAuth.objects.filter( + username=request.user.username, provider='weixin-work').count() > 0 + resp_dict = { 'form': form, 'server_crypto': server_crypto, @@ -102,6 +106,8 @@ def edit_profile(request): 'ENABLE_UPDATE_USER_INFO': ENABLE_UPDATE_USER_INFO, 'webdav_passwd': webdav_passwd, 'email_notification_interval': email_inverval, + 'social_connected': social_connected, + 'social_next_page': reverse('edit_profile'), } if has_two_factor_auth(): diff --git a/seahub/settings.py b/seahub/settings.py index 736bb74ee0..4e513f00ad 100644 --- a/seahub/settings.py +++ b/seahub/settings.py @@ -125,6 +125,7 @@ MIDDLEWARE_CLASSES = ( 'seahub.two_factor.middleware.OTPMiddleware', 'seahub.two_factor.middleware.ForceTwoFactorAuthMiddleware', 'seahub.trusted_ip.middleware.LimitIpMiddleware', + 'social_django.middleware.SocialAuthExceptionMiddleware', ) @@ -152,6 +153,9 @@ TEMPLATES = [ 'django.template.context_processors.request', 'django.contrib.messages.context_processors.messages', + 'social_django.context_processors.backends', + 'social_django.context_processors.login_redirect', + 'seahub.auth.context_processors.auth', 'seahub.base.context_processors.base', 'seahub.base.context_processors.debug', @@ -223,6 +227,7 @@ INSTALLED_APPS = ( 'post_office', 'termsandconditions', 'webpack_loader', + 'social_django', 'seahub.api2', 'seahub.avatar', @@ -264,17 +269,42 @@ CONSTANCE_BACKEND = 'constance.backends.database.DatabaseBackend' CONSTANCE_DATABASE_CACHE_BACKEND = 'default' AUTHENTICATION_BACKENDS = ( + 'seahub.social_core.backends.weixin_enterprise.WeixinWorkOAuth2', + 'seahub.base.accounts.AuthBackend', 'seahub.oauth.backends.OauthRemoteUserBackend', ) +SOCIAL_AUTH_URL_NAMESPACE = 'social' +SOCIAL_AUTH_VERIFY_SSL = True +SOCIAL_AUTH_METHODS = ( + ('weixin', 'WeChat'), +) + +SOCIAL_AUTH_WEIXIN_WORK_AGENTID = '' +SOCIAL_AUTH_WEIXIN_WORK_KEY = '' +SOCIAL_AUTH_WEIXIN_WORK_SECRET = '' + +SOCIAL_AUTH_PIPELINE = ( + 'social_core.pipeline.social_auth.social_details', + 'social_core.pipeline.social_auth.social_uid', + 'social_core.pipeline.social_auth.auth_allowed', + 'seahub.social_core.pipeline.social_auth.social_user', + 'seahub.social_core.pipeline.user.get_username', + 'seahub.social_core.pipeline.user.create_user', + 'seahub.social_core.pipeline.social_auth.associate_user', + 'social_core.pipeline.social_auth.load_extra_data', + # 'social_core.pipeline.user.user_details', + 'seahub.social_core.pipeline.user.save_profile', +) + ENABLE_OAUTH = False ENABLE_WATERMARK = False # allow user to clean library trash ENABLE_USER_CLEAN_TRASH = True -LOGIN_REDIRECT_URL = '/profile/' +LOGIN_REDIRECT_URL = '/' LOGIN_URL = '/accounts/login/' LOGOUT_URL = '/accounts/logout/' LOGOUT_REDIRECT_URL = None diff --git a/seahub/social_core/__init__.py b/seahub/social_core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seahub/social_core/backends/__init__.py b/seahub/social_core/backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seahub/social_core/backends/weixin_enterprise.py b/seahub/social_core/backends/weixin_enterprise.py new file mode 100644 index 0000000000..53ea054aab --- /dev/null +++ b/seahub/social_core/backends/weixin_enterprise.py @@ -0,0 +1,198 @@ +import urllib +from requests import HTTPError + +from django.conf import settings + +from social_core.backends.oauth import BaseOAuth2 +from social_core.exceptions import AuthCanceled, AuthUnknownError + +import logging +logger = logging.getLogger(__name__) + +try: + WEIXIN_WORK_SP = True if settings.SOCIAL_AUTH_WEIXIN_WORK_SUITID else False +except AttributeError: + WEIXIN_WORK_SP = False + +if WEIXIN_WORK_SP is True: + _AUTHORIZATION_URL = 'https://open.work.weixin.qq.com/wwopen/sso/3rd_qrConnect' + _ACCESS_TOKEN_URL = 'https://qyapi.weixin.qq.com/cgi-bin/service/get_provider_token' + _USER_INFO_URL = 'https://qyapi.weixin.qq.com/cgi-bin/service/get_login_info' +else: + _AUTHORIZATION_URL = 'https://open.work.weixin.qq.com/wwopen/sso/qrConnect' + _ACCESS_TOKEN_URL = 'https://qyapi.weixin.qq.com/cgi-bin/token' + _USER_INFO_URL = 'https://qyapi.weixin.qq.com/cgi-bin/user/getuserinfo' + + +class WeixinWorkOAuth2(BaseOAuth2): + """WeChat Work OAuth authentication backend""" + name = 'weixin-work' + ID_KEY = 'UserId' + AUTHORIZATION_URL = _AUTHORIZATION_URL + ACCESS_TOKEN_URL = _ACCESS_TOKEN_URL + ACCESS_TOKEN_METHOD = 'POST' + DEFAULT_SCOPE = ['snsapi_login'] + REDIRECT_STATE = False + EXTRA_DATA = [ + ('nickname', 'username'), + ('headimgurl', 'profile_image_url'), + ] + + def extra_data(self, user, uid, response, details=None, *args, **kwargs): + data = super(BaseOAuth2, self).extra_data(user, uid, response, + details=details, + *args, **kwargs) + + if WEIXIN_WORK_SP: + data['corp_info'] = response.get('corp_info') + data['user_info'] = response.get('user_info') + + return data + + def get_user_id(self, details, response): + """Return a unique ID for the current user, by default from server + response.""" + if WEIXIN_WORK_SP: + return response.get('user_info').get('userid') + else: + return response.get(self.ID_KEY) + + def get_user_details(self, response): + """Return user details from Weixin. API URL is: + https://api.weixin.qq.com/sns/userinfo + """ + if WEIXIN_WORK_SP: + user_info = response.get('user_info') + return { + 'userid': user_info.get('userid'), + 'user_name': user_info.get('name'), + 'user_avatar': user_info.get('avatar'), + 'corpid': response.get('corp_info').get('corpid'), + } + else: + if self.setting('DOMAIN_AS_USERNAME'): + username = response.get('domain', '') + else: + username = response.get('nickname', '') + return { + 'username': username, + 'profile_image_url': response.get('headimgurl', '') + } + + def user_data(self, access_token, *args, **kwargs): + if WEIXIN_WORK_SP: + data = self.get_json(_USER_INFO_URL, + params={'access_token': access_token}, + json={'auth_code': kwargs['request'].GET.get('auth_code')}, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + method='post') + + else: + data = self.get_json(_USER_INFO_URL, params={ + 'access_token': access_token, + 'code': kwargs['request'].GET.get('code') + }) + + nickname = data.get('nickname') + if nickname: + # weixin api has some encode bug, here need handle + data['nickname'] = nickname.encode( + 'raw_unicode_escape' + ).decode('utf-8') + + return data + + def auth_params(self, state=None): + appid, secret = self.get_key_and_secret() + + if WEIXIN_WORK_SP: + params = { + 'appid': appid, + 'redirect_uri': self.get_redirect_uri(state), + 'usertype': 'member', + } + else: + params = { + 'appid': appid, + 'redirect_uri': self.get_redirect_uri(state), + 'agentid': self.setting('AGENTID'), + } + + if self.STATE_PARAMETER and state: + params['state'] = state + if self.RESPONSE_TYPE: + params['response_type'] = self.RESPONSE_TYPE + return params + + def auth_complete_params(self, state=None): + appid, secret = self.get_key_and_secret() + if WEIXIN_WORK_SP is True: + return { + 'corpid': appid, + 'provider_secret': secret, + } + + return { + 'grant_type': 'authorization_code', # request auth code + 'code': self.data.get('code', ''), # server response code + 'appid': appid, + 'secret': secret, + 'redirect_uri': self.get_redirect_uri(state), + } + + def refresh_token_params(self, token, *args, **kwargs): + appid, secret = self.get_key_and_secret() + return { + 'refresh_token': token, + 'grant_type': 'refresh_token', + 'appid': appid, + 'secret': secret + } + + def access_token_url(self, appid, secret): + if WEIXIN_WORK_SP: + return self.ACCESS_TOKEN_URL + else: + return self.ACCESS_TOKEN_URL + '?corpid=%s&corpsecret=%s' % (appid, secret) + + def auth_complete(self, *args, **kwargs): + """Completes loging process, must return user instance""" + self.process_error(self.data) + + appid, secret = self.get_key_and_secret() + try: + if WEIXIN_WORK_SP: + response = self.request_access_token( + self.access_token_url(appid, secret), + json=self.auth_complete_params(self.validate_state()), + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + method=self.ACCESS_TOKEN_METHOD + ) + else: + response = self.request_access_token( + self.access_token_url(appid, secret), + data=self.auth_complete_params(self.validate_state()), + headers=self.auth_headers(), + method=self.ACCESS_TOKEN_METHOD + ) + except HTTPError as err: + if err.response.status_code == 400: + raise AuthCanceled(self, response=err.response) + else: + raise + except KeyError: + raise AuthUnknownError(self) + + try: + if response['errmsg'] != 'ok': + raise AuthCanceled(self) + except KeyError: + pass # assume response is ok if 'errmsg' key not found + + self.process_error(response) + + access_token = response['provider_access_token'] if WEIXIN_WORK_SP else response['access_token'] + return self.do_auth(access_token, response=response, + *args, **kwargs) diff --git a/seahub/social_core/pipeline/__init__.py b/seahub/social_core/pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seahub/social_core/pipeline/social_auth.py b/seahub/social_core/pipeline/social_auth.py new file mode 100644 index 0000000000..efc9fed943 --- /dev/null +++ b/seahub/social_core/pipeline/social_auth.py @@ -0,0 +1,34 @@ +from social_core.exceptions import AuthAlreadyAssociated + +def social_user(backend, uid, user=None, *args, **kwargs): + provider = backend.name + social = backend.strategy.storage.user.get_social_auth(provider, uid) + if social: + if user and social.user.username != user.username: + msg = 'This {0} account is already in use.'.format(provider) + raise AuthAlreadyAssociated(backend, msg) + elif not user: + user = social.user + return {'social': social, + 'user': user, + 'is_new': user is None, + 'new_association': social is None} + + +def associate_user(backend, uid, user=None, social=None, *args, **kwargs): + if user and not social: + try: + social = backend.strategy.storage.user.create_social_auth( + user, uid, backend.name + ) + except Exception as err: + if not backend.strategy.storage.is_integrity_error(err): + raise + # Protect for possible race condition, those bastard with FTL + # clicking capabilities, check issue #131: + # https://github.com/omab/django-social-auth/issues/131 + return social_user(backend, uid, user, *args, **kwargs) + else: + return {'social': social, + 'user': user, + 'new_association': True} diff --git a/seahub/social_core/pipeline/user.py b/seahub/social_core/pipeline/user.py new file mode 100644 index 0000000000..f25cf06fc3 --- /dev/null +++ b/seahub/social_core/pipeline/user.py @@ -0,0 +1,87 @@ +from seahub.profile.models import Profile +from seahub.utils.auth import gen_user_virtual_id + +USER_FIELDS = ['username', 'email'] + + +def get_username(strategy, details, backend, user=None, *args, **kwargs): + if 'username' not in backend.setting('USER_FIELDS', USER_FIELDS): + return + storage = strategy.storage + + if not user: + final_username = gen_user_virtual_id() + else: + final_username = storage.user.get_username(user) + + return {'username': final_username} + + +def create_user(strategy, details, backend, user=None, *args, **kwargs): + if user: + return {'is_new': False} + + fields = dict((name, kwargs.get(name, details.get(name))) + for name in backend.setting('USER_FIELDS', USER_FIELDS)) + if not fields: + return + + return { + 'is_new': True, + 'user': strategy.create_user(**fields) + } + + + +def save_profile(strategy, details, backend, user=None, *args, **kwargs): + if not user: + return + email = details.get('email', '') + if email: + Profile.objects.add_or_update(username=user.username, + contact_email=email) + + fullname = details.get('fullname', '') + if fullname: + Profile.objects.add_or_update(username=user.username, + nickname=fullname) + + # weixin username and profile_image_url + nickname = details.get('username', '') + if nickname: + Profile.objects.add_or_update(username=user.username, + nickname=nickname) + + avatar_url = details.get('profile_image_url', '') + if avatar_url: + _update_user_avatar(user, avatar_url) + +import os +import logging +import urllib2 +from django.core.files import File +from seahub.avatar.models import Avatar +from seahub.avatar.signals import avatar_updated +logger = logging.getLogger(__name__) + +def _update_user_avatar(user, pic): + if not pic: + return + + logger.info("retrieve pic from %s" % pic) + + filedata = urllib2.urlopen(pic) + datatowrite = filedata.read() + filename = '/tmp/%s.jpg' % user.username + with open(filename, 'wb') as f: + f.write(datatowrite) + + logger.info("save pic to %s" % filename) + avatar = Avatar(emailuser=user.username, primary=True) + avatar.avatar.save( + 'image.jpg', File(open(filename)) + ) + avatar.save() + avatar_updated.send(sender=Avatar, user=user, avatar=avatar) + + os.remove(filename) diff --git a/seahub/social_core/utils/WXBizMsgCrypt.py b/seahub/social_core/utils/WXBizMsgCrypt.py new file mode 100644 index 0000000000..3b515d3430 --- /dev/null +++ b/seahub/social_core/utils/WXBizMsgCrypt.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python +#-*- encoding:utf-8 -*- + +""" 对企业微信发送给企业后台的消息加解密示例代码. +@copyright: Copyright (c) 1998-2014 Tencent Inc. + +""" +# ------------------------------------------------------------------------ + +import base64 +import string +import random +import hashlib +import time +import struct +try: + from Crypto.Cipher import AES +except ImportError: + AES = None +import xml.etree.cElementTree as ET +import sys +import socket +reload(sys) +from . import ierror +sys.setdefaultencoding('utf-8') + +""" +关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案 +请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。 +下载后,按照README中的“Installation”小节的提示进行pycrypto安装。 +""" +class FormatException(Exception): + pass + +def throw_exception(message, exception_class=FormatException): + """my define raise exception function""" + raise exception_class(message) + +class SHA1: + """计算企业微信的消息签名接口""" + + def getSHA1(self, token, timestamp, nonce, encrypt): + """用SHA1算法生成安全签名 + @param token: 票据 + @param timestamp: 时间戳 + @param encrypt: 密文 + @param nonce: 随机字符串 + @return: 安全签名 + """ + try: + sortlist = [token, timestamp, nonce, encrypt] + sortlist.sort() + sha = hashlib.sha1() + sha.update("".join(sortlist)) + return ierror.WXBizMsgCrypt_OK, sha.hexdigest() + except Exception,e: + print e + return ierror.WXBizMsgCrypt_ComputeSignature_Error, None + + +class XMLParse: + """提供提取消息格式中的密文及生成回复消息格式的接口""" + + # xml消息模板 + AES_TEXT_RESPONSE_TEMPLATE = """ + + +%(timestamp)s + +""" + + def extract(self, xmltext): + """提取出xml数据包中的加密消息 + @param xmltext: 待提取的xml字符串 + @return: 提取出的加密消息字符串 + """ + try: + xml_tree = ET.fromstring(xmltext) + encrypt = xml_tree.find("Encrypt") + return ierror.WXBizMsgCrypt_OK, encrypt.text + except Exception,e: + print e + return ierror.WXBizMsgCrypt_ParseXml_Error,None,None + + def generate(self, encrypt, signature, timestamp, nonce): + """生成xml消息 + @param encrypt: 加密后的消息密文 + @param signature: 安全签名 + @param timestamp: 时间戳 + @param nonce: 随机字符串 + @return: 生成的xml字符串 + """ + resp_dict = { + 'msg_encrypt' : encrypt, + 'msg_signaturet': signature, + 'timestamp' : timestamp, + 'nonce' : nonce, + } + resp_xml = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict + return resp_xml + + +class PKCS7Encoder(): + """提供基于PKCS7算法的加解密接口""" + + block_size = 32 + def encode(self, text): + """ 对需要加密的明文进行填充补位 + @param text: 需要进行填充补位操作的明文 + @return: 补齐明文字符串 + """ + text_length = len(text) + # 计算需要填充的位数 + amount_to_pad = self.block_size - (text_length % self.block_size) + if amount_to_pad == 0: + amount_to_pad = self.block_size + # 获得补位所用的字符 + pad = chr(amount_to_pad) + return text + pad * amount_to_pad + + def decode(self, decrypted): + """删除解密后明文的补位字符 + @param decrypted: 解密后的明文 + @return: 删除补位字符后的明文 + """ + pad = ord(decrypted[-1]) + if pad<1 or pad >32: + pad = 0 + return decrypted[:-pad] + + +class Prpcrypt(object): + """提供接收和推送给企业微信消息的加解密接口""" + + def __init__(self,key): + + #self.key = base64.b64decode(key+"=") + self.key = key + # 设置加解密模式为AES的CBC模式 + self.mode = AES.MODE_CBC + + + def encrypt(self,text,receiveid): + """对明文进行加密 + @param text: 需要加密的明文 + @return: 加密得到的字符串 + """ + # 16位随机字符串添加到明文开头 + text = self.get_random_str() + struct.pack("I",socket.htonl(len(text))) + text + receiveid + # 使用自定义的填充方式对明文进行补位填充 + pkcs7 = PKCS7Encoder() + text = pkcs7.encode(text) + # 加密 + cryptor = AES.new(self.key,self.mode,self.key[:16]) + try: + ciphertext = cryptor.encrypt(text) + # 使用BASE64对加密后的字符串进行编码 + return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext) + except Exception,e: + print e + return ierror.WXBizMsgCrypt_EncryptAES_Error,None + + def decrypt(self,text,receiveid): + """对解密后的明文进行补位删除 + @param text: 密文 + @return: 删除填充补位后的明文 + """ + try: + cryptor = AES.new(self.key,self.mode,self.key[:16]) + # 使用BASE64对密文进行解码,然后AES-CBC解密 + plain_text = cryptor.decrypt(base64.b64decode(text)) + except Exception,e: + print e + return ierror.WXBizMsgCrypt_DecryptAES_Error,None + try: + pad = ord(plain_text[-1]) + # 去掉补位字符串 + #pkcs7 = PKCS7Encoder() + #plain_text = pkcs7.encode(plain_text) + # 去除16位随机字符串 + content = plain_text[16:-pad] + xml_len = socket.ntohl(struct.unpack("I",content[ : 4])[0]) + xml_content = content[4 : xml_len+4] + from_receiveid = content[xml_len+4:] + except Exception,e: + print e + return ierror.WXBizMsgCrypt_IllegalBuffer,None + if from_receiveid != receiveid: + return ierror.WXBizMsgCrypt_ValidateCorpid_Error,None + return 0,xml_content + + def get_random_str(self): + """ 随机生成16位字符串 + @return: 16位字符串 + """ + rule = string.letters + string.digits + str = random.sample(rule, 16) + return "".join(str) + +class WXBizMsgCrypt(object): + #构造函数 + def __init__(self,sToken,sEncodingAESKey,sReceiveId): + try: + self.key = base64.b64decode(sEncodingAESKey+"=") + assert len(self.key) == 32 + except: + throw_exception("[error]: EncodingAESKey unvalid !", FormatException) + # return ierror.WXBizMsgCrypt_IllegalAesKey,None + self.m_sToken = sToken + self.m_sReceiveId = sReceiveId + + #验证URL + #@param sMsgSignature: 签名串,对应URL参数的msg_signature + #@param sTimeStamp: 时间戳,对应URL参数的timestamp + #@param sNonce: 随机串,对应URL参数的nonce + #@param sEchoStr: 随机串,对应URL参数的echostr + #@param sReplyEchoStr: 解密之后的echostr,当return返回0时有效 + #@return:成功0,失败返回对应的错误码 + + def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr): + sha1 = SHA1() + ret,signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr) + if ret != 0: + return ret, None + if not signature == sMsgSignature: + return ierror.WXBizMsgCrypt_ValidateSignature_Error, None + pc = Prpcrypt(self.key) + ret,sReplyEchoStr = pc.decrypt(sEchoStr,self.m_sReceiveId) + return ret,sReplyEchoStr + + def EncryptMsg(self, sReplyMsg, sNonce, timestamp = None): + #将企业回复用户的消息加密打包 + #@param sReplyMsg: 企业号待回复用户的消息,xml格式的字符串 + #@param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间 + #@param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce + #sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的xml格式的字符串, + #return:成功0,sEncryptMsg,失败返回对应的错误码None + pc = Prpcrypt(self.key) + ret,encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId) + if ret != 0: + return ret,None + if timestamp is None: + timestamp = str(int(time.time())) + # 生成安全签名 + sha1 = SHA1() + ret,signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt) + if ret != 0: + return ret,None + xmlParse = XMLParse() + return ret,xmlParse.generate(encrypt, signature, timestamp, sNonce) + + def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce): + # 检验消息的真实性,并且获取解密后的明文 + # @param sMsgSignature: 签名串,对应URL参数的msg_signature + # @param sTimeStamp: 时间戳,对应URL参数的timestamp + # @param sNonce: 随机串,对应URL参数的nonce + # @param sPostData: 密文,对应POST请求的数据 + # xml_content: 解密后的原文,当return返回0时有效 + # @return: 成功0,失败返回对应的错误码 + # 验证安全签名 + xmlParse = XMLParse() + ret,encrypt = xmlParse.extract(sPostData) + if ret != 0: + return ret, None + sha1 = SHA1() + ret,signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt) + if ret != 0: + return ret, None + if not signature == sMsgSignature: + return ierror.WXBizMsgCrypt_ValidateSignature_Error, None + pc = Prpcrypt(self.key) + ret,xml_content = pc.decrypt(encrypt,self.m_sReceiveId) + return ret,xml_content + + diff --git a/seahub/social_core/utils/__init__.py b/seahub/social_core/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seahub/social_core/utils/ierror.py b/seahub/social_core/utils/ierror.py new file mode 100644 index 0000000000..6678fecfd6 --- /dev/null +++ b/seahub/social_core/utils/ierror.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +######################################################################### +# Author: jonyqin +# Created Time: Thu 11 Sep 2014 01:53:58 PM CST +# File Name: ierror.py +# Description:定义错误码含义 +######################################################################### +WXBizMsgCrypt_OK = 0 +WXBizMsgCrypt_ValidateSignature_Error = -40001 +WXBizMsgCrypt_ParseXml_Error = -40002 +WXBizMsgCrypt_ComputeSignature_Error = -40003 +WXBizMsgCrypt_IllegalAesKey = -40004 +WXBizMsgCrypt_ValidateCorpid_Error = -40005 +WXBizMsgCrypt_EncryptAES_Error = -40006 +WXBizMsgCrypt_DecryptAES_Error = -40007 +WXBizMsgCrypt_IllegalBuffer = -40008 +WXBizMsgCrypt_EncodeBase64_Error = -40009 +WXBizMsgCrypt_DecodeBase64_Error = -40010 +WXBizMsgCrypt_GenReturnXml_Error = -40011 diff --git a/seahub/social_core/views.py b/seahub/social_core/views.py new file mode 100644 index 0000000000..18190ea15a --- /dev/null +++ b/seahub/social_core/views.py @@ -0,0 +1,145 @@ +import logging + +import requests +from django.conf import settings +from django.core.cache import cache +from django.core.urlresolvers import reverse +from django.http import HttpResponse, HttpResponseRedirect +from django.views.decorators.csrf import csrf_exempt +from django.utils.http import urlquote + +from seahub.social_core.utils.WXBizMsgCrypt import WXBizMsgCrypt +from seahub.utils.urls import abs_reverse + +# Get an instance of a logger +logger = logging.getLogger(__name__) + +@csrf_exempt +def weixin_work_cb(request): + """Callback for weixin work provider API. + + Used in callback config at app details page. + e.g. https://open.work.weixin.qq.com/wwopen/developer#/sass/apps/detail/ww24c53566499d354f + + ref: https://work.weixin.qq.com/api/doc#90001/90143/91116 + """ + + token = settings.SOCIAL_AUTH_WEIXIN_WORK_TOKEN + EncodingAESKey = settings.SOCIAL_AUTH_WEIXIN_WORK_AES_KEY + + msg_signature = request.GET.get('msg_signature', None) + timestamp = request.GET.get('timestamp', None) + nonce = request.GET.get('nonce', None) + if not (msg_signature and timestamp and nonce): + assert False, 'Request Error' + + if request.method == 'GET': + wxcpt = WXBizMsgCrypt(token, EncodingAESKey, + settings.SOCIAL_AUTH_WEIXIN_WORK_KEY) + + echostr = request.GET.get('echostr', '') + ret, decoded_echostr = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) + if ret != 0: + assert False, 'Verify Error' + + return HttpResponse(decoded_echostr) + + elif request.method == 'POST': + wxcpt = WXBizMsgCrypt(token, EncodingAESKey, + settings.SOCIAL_AUTH_WEIXIN_WORK_SUITID) + + ret, xml_msg = wxcpt.DecryptMsg(request.body, msg_signature, timestamp, nonce) + if ret != 0: + assert False, 'Decrypt Error' + + import xml.etree.cElementTree as ET + xml_tree = ET.fromstring(xml_msg) + suite_ticket = xml_tree.find("SuiteTicket").text + logger.info('suite ticket: %s' % suite_ticket) + + # TODO: use persistent store + cache.set('wx_work_suite_ticket', suite_ticket, 3600) + + return HttpResponse('success') + +def _get_suite_access_token(): + suite_access_token = cache.get('wx_work_suite_access_token', None) + if suite_access_token: + return suite_access_token + + suite_ticket = cache.get('wx_work_suite_ticket', None) + if not suite_ticket: + assert False, 'suite ticket is None!' + + get_suite_token_url = 'https://qyapi.weixin.qq.com/cgi-bin/service/get_suite_token' + resp = requests.request( + 'POST', get_suite_token_url, + json={ + "suite_id": settings.SOCIAL_AUTH_WEIXIN_WORK_SUITID, + "suite_secret": settings.SOCIAL_AUTH_WEIXIN_WORK_SUIT_SECRET, + "suite_ticket": suite_ticket, + }, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + ) + + suite_access_token = resp.json().get('suite_access_token', None) + if not suite_access_token: + logger.error('Failed to get suite_access_token!') + logger.error(resp.content) + assert False, 'suite_access_token is None!' + else: + cache.set('wx_work_suite_access_token', suite_access_token, 3600) + return suite_access_token + +def weixin_work_3rd_app_install(request): + """Redirect user to weixin work 3rd app install page. + """ + # 0. get suite access token + suite_access_token = _get_suite_access_token() + print('suite access token', suite_access_token) + + # 1. get pre_auth_code + get_pre_auth_code_url = 'https://qyapi.weixin.qq.com/cgi-bin/service/get_pre_auth_code?suite_access_token=' + suite_access_token + resp = requests.request('GET', get_pre_auth_code_url) + + pre_auth_code = resp.json().get('pre_auth_code', None) + if not pre_auth_code: + logger.error('Failed to get pre_auth_code') + logger.error(resp.content) + assert False, 'pre_auth_code is None' + + # 2. set session info + # ref: https://work.weixin.qq.com/api/doc#90001/90143/90602 + url = 'https://qyapi.weixin.qq.com/cgi-bin/service/set_session_info?suite_access_token=' + suite_access_token + resp = requests.request( + 'POST', url, + json={ + "pre_auth_code": pre_auth_code, + "session_info": + { + "appid": [], + "auth_type": 1 # TODO: 0: production; 1: testing. + } + }, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + + ) + + # TODO: use random state + url = 'https://open.work.weixin.qq.com/3rdapp/install?suite_id=%s&pre_auth_code=%s&redirect_uri=%s&state=STATE123' % ( + settings.SOCIAL_AUTH_WEIXIN_WORK_SUITID, + pre_auth_code, + abs_reverse('weixin_work_3rd_app_install_cb'), + ) + return HttpResponseRedirect(url) + +@csrf_exempt +def weixin_work_3rd_app_install_cb(request): + """Callback for weixin work 3rd app install API. + + https://work.weixin.qq.com/api/doc#90001/90143/90597 + """ + # TODO: check state + pass diff --git a/seahub/urls.py b/seahub/urls.py index 80fbd4c875..955ee5354f 100644 --- a/seahub/urls.py +++ b/seahub/urls.py @@ -137,6 +137,7 @@ urlpatterns = [ url(r'^sso/$', sso, name='sso'), url(r'^shib-login/', shib_login, name="shib_login"), url(r'^oauth/', include('seahub.oauth.urls')), + url(r'^social/', include('social_django.urls', namespace='social')), url(r'^$', libraries, name='libraries'), #url(r'^home/$', direct_to_template, { 'template': 'home.html' } ), @@ -693,3 +694,15 @@ if getattr(settings, 'ENABLE_CAS', False): url(r'^accounts/cas-logout/$', cas_logout, name='cas_ng_logout'), url(r'^accounts/cas-callback/$', cas_callback, name='cas_ng_proxy_callback'), ] + + +from seahub.social_core.views import ( + weixin_work_cb, weixin_work_3rd_app_install, weixin_work_3rd_app_install_cb +) + +urlpatterns += [ + url(r'^weixin-work/callback/$', weixin_work_cb), + url(r'^weixin-work/3rd-app-install/$', weixin_work_3rd_app_install), + url(r'^weixin-work/3rd-app-install/callback/$', + weixin_work_3rd_app_install_cb, name='weixin_work_3rd_app_install_cb'), +] diff --git a/seahub/utils/auth.py b/seahub/utils/auth.py index 0073f5d5ab..5155b078df 100644 --- a/seahub/utils/auth.py +++ b/seahub/utils/auth.py @@ -1,5 +1,6 @@ import os from seahub.settings import LOGIN_BG_IMAGE_PATH, MEDIA_ROOT +from seahub.utils import gen_token def get_login_bg_image_path(): """ Return custom background image path if it exists, otherwise return default background image path. @@ -15,3 +16,6 @@ def get_custom_login_bg_image_path(): """ Ensure consistency between utils and api. """ return 'custom/login-bg.jpg' + +def gen_user_virtual_id(): + return gen_token(max_length=32) + '@auth.local' diff --git a/seahub/utils/urls.py b/seahub/utils/urls.py new file mode 100644 index 0000000000..8f453ffca6 --- /dev/null +++ b/seahub/utils/urls.py @@ -0,0 +1,7 @@ +from django.core.urlresolvers import reverse + +from seahub.utils import get_site_scheme_and_netloc + +def abs_reverse(viewname, urlconf=None, args=None, kwargs=None, current_app=None): + return get_site_scheme_and_netloc().rstrip('/') + reverse( + viewname, urlconf, args, kwargs, current_app) diff --git a/tests/seahub/social_core/__init__.py b/tests/seahub/social_core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/seahub/social_core/test_views.py b/tests/seahub/social_core/test_views.py new file mode 100644 index 0000000000..ab624da847 --- /dev/null +++ b/tests/seahub/social_core/test_views.py @@ -0,0 +1,22 @@ +import os +import pytest + +from seahub.test_utils import BaseTestCase + +TRAVIS = 'TRAVIS' in os.environ + + +class WeixinWorkCBTest(BaseTestCase): + @pytest.mark.skipif(TRAVIS, reason="This test can only be run in local.") + def test_get(self, ): + resp = self.client.get('/weixin-work/callback/?msg_signature=61a7d120857cdb70d8b936ec5b6e8ed172a41926×tamp=1543304575&nonce=1542460575&echostr=9uB%2FReg5PQk%2FjzejPjhjWmvKXuxh0R4VK7BJRP62lfRj5kZhuAu0mLMM7hnREJQTJxWWw3Y1BB%2F%2FLkE3V88auA%3D%3D') + assert resp.content == '6819653789729882111' + + @pytest.mark.skipif(TRAVIS, reason="This test can only be run in local.") + def test_post(self, ): + data = '' + resp = self.client.post( + '/weixin-work/callback/?msg_signature=a237bf482cc9ae8424010eb63a24859c731b2aa7×tamp=1543309590&nonce=1542845878', + data=data, + content_type='application/xml', + ) diff --git a/thirdpart/social_django/__init__.py b/thirdpart/social_django/__init__.py new file mode 100644 index 0000000000..5041050ce8 --- /dev/null +++ b/thirdpart/social_django/__init__.py @@ -0,0 +1,23 @@ +__version__ = '2.1.0' + + +from social_core.backends.base import BaseAuth + +# django.contrib.auth.load_backend() will import and instanciate the +# authentication backend ignoring the possibility that it might +# require more arguments. Here we set a monkey patch to +# BaseAuth.__init__ to ignore the mandatory strategy argument and load +# it. + +def baseauth_init_workaround(original_init): + def fake_init(self, strategy=None, *args, **kwargs): + from .utils import load_strategy + original_init(self, strategy or load_strategy(), *args, **kwargs) + return fake_init + + +if not getattr(BaseAuth, '__init_patched', False): + BaseAuth.__init__ = baseauth_init_workaround(BaseAuth.__init__) + BaseAuth.__init_patched = True + +default_app_config = 'social_django.config.PythonSocialAuthConfig' diff --git a/thirdpart/social_django/admin.py b/thirdpart/social_django/admin.py new file mode 100644 index 0000000000..635d3d821b --- /dev/null +++ b/thirdpart/social_django/admin.py @@ -0,0 +1,62 @@ +"""Admin settings""" +from itertools import chain + +from django.conf import settings +from django.contrib import admin + +from social_core.utils import setting_name +from .models import UserSocialAuth, Nonce, Association + + +class UserSocialAuthOption(admin.ModelAdmin): + """Social Auth user options""" + list_display = ('user', 'id', 'provider', 'uid') + list_filter = ('provider',) + raw_id_fields = ('user',) + list_select_related = True + + def get_search_fields(self, request=None): + search_fields = getattr( + settings, setting_name('ADMIN_USER_SEARCH_FIELDS'), None + ) + if search_fields is None: + _User = UserSocialAuth.user_model() + username = getattr(_User, 'USERNAME_FIELD', None) or \ + hasattr(_User, 'username') and 'username' or \ + None + fieldnames = ('first_name', 'last_name', 'email', username) + all_names = self._get_all_field_names(_User._meta) + search_fields = [name for name in fieldnames + if name and name in all_names] + return ['user__' + name for name in search_fields] + \ + getattr(settings, setting_name('ADMIN_SEARCH_FIELDS'), []) + + @staticmethod + def _get_all_field_names(model): + names = chain.from_iterable( + (field.name, field.attname) + if hasattr(field, 'attname') else (field.name,) + for field in model.get_fields() + # For complete backwards compatibility, you may want to exclude + # GenericForeignKey from the results. + if not (field.many_to_one and field.related_model is None) + ) + return list(set(names)) + + +class NonceOption(admin.ModelAdmin): + """Nonce options""" + list_display = ('id', 'server_url', 'timestamp', 'salt') + search_fields = ('server_url',) + + +class AssociationOption(admin.ModelAdmin): + """Association options""" + list_display = ('id', 'server_url', 'assoc_type') + list_filter = ('assoc_type',) + search_fields = ('server_url',) + + +admin.site.register(UserSocialAuth, UserSocialAuthOption) +admin.site.register(Nonce, NonceOption) +admin.site.register(Association, AssociationOption) diff --git a/thirdpart/social_django/compat.py b/thirdpart/social_django/compat.py new file mode 100644 index 0000000000..4789849cac --- /dev/null +++ b/thirdpart/social_django/compat.py @@ -0,0 +1,34 @@ +# coding=utf-8 +import six +import django +from django.db import models + + +try: + from django.urls import reverse +except ImportError: + from django.core.urlresolvers import reverse + +try: + from django.utils.deprecation import MiddlewareMixin +except ImportError: + MiddlewareMixin = object + + +def get_rel_model(field): + if django.VERSION >= (2, 0): + return field.remote_field.model + + user_model = field.rel.to + if isinstance(user_model, six.string_types): + app_label, model_name = user_model.split('.') + user_model = models.get_model(app_label, model_name) + return user_model + + +def get_request_port(request): + if django.VERSION >= (1, 9): + return request.get_port() + + host_parts = request.get_host().partition(':') + return host_parts[2] or request.META['SERVER_PORT'] diff --git a/thirdpart/social_django/config.py b/thirdpart/social_django/config.py new file mode 100644 index 0000000000..c1491bb184 --- /dev/null +++ b/thirdpart/social_django/config.py @@ -0,0 +1,10 @@ +from django.apps import AppConfig + + +class PythonSocialAuthConfig(AppConfig): + # Full Python path to the application eg. 'django.contrib.admin'. + name = 'social_django' + # Last component of the Python path to the application eg. 'admin'. + label = 'social_django' + # Human-readable name for the application eg. "Admin". + verbose_name = 'Python Social Auth' diff --git a/thirdpart/social_django/context_processors.py b/thirdpart/social_django/context_processors.py new file mode 100644 index 0000000000..07e875c521 --- /dev/null +++ b/thirdpart/social_django/context_processors.py @@ -0,0 +1,52 @@ +from django.contrib.auth import REDIRECT_FIELD_NAME +from django.utils.functional import SimpleLazyObject +from django.utils.http import urlquote + +try: + from django.utils.functional import empty as _empty + empty = _empty +except ImportError: # django < 1.4 + empty = None + + +from social_core.backends.utils import user_backends_data +from .utils import Storage, BACKENDS + + +class LazyDict(SimpleLazyObject): + """Lazy dict initialization.""" + def __getitem__(self, name): + if self._wrapped is empty: + self._setup() + return self._wrapped[name] + + def __setitem__(self, name, value): + if self._wrapped is empty: + self._setup() + self._wrapped[name] = value + + +def backends(request): + """Load Social Auth current user data to context under the key 'backends'. + Will return the output of social_core.backends.utils.user_backends_data.""" + return {'backends': LazyDict(lambda: user_backends_data(request.user, + BACKENDS, + Storage))} + + +def login_redirect(request): + """Load current redirect to context.""" + value = request.method == 'POST' and \ + request.POST.get(REDIRECT_FIELD_NAME) or \ + request.GET.get(REDIRECT_FIELD_NAME) + if value: + value = urlquote(value) + querystring = REDIRECT_FIELD_NAME + '=' + value + else: + querystring = '' + + return { + 'REDIRECT_FIELD_NAME': REDIRECT_FIELD_NAME, + 'REDIRECT_FIELD_VALUE': value, + 'REDIRECT_QUERYSTRING': querystring + } diff --git a/thirdpart/social_django/fields.py b/thirdpart/social_django/fields.py new file mode 100644 index 0000000000..d547ce8e9f --- /dev/null +++ b/thirdpart/social_django/fields.py @@ -0,0 +1,94 @@ +import json +import six +import functools + +import django + +from django.core.exceptions import ValidationError +from django.conf import settings +from django.db import models + +from social_core.utils import setting_name + +try: + from django.utils.encoding import smart_unicode as smart_text + smart_text # placate pyflakes +except ImportError: + from django.utils.encoding import smart_text + +# SubfieldBase causes RemovedInDjango110Warning in 1.8 and 1.9, and +# will not work in 1.10 or later +if django.VERSION[:2] >= (1, 8): + field_metaclass = type +else: + from django.db.models import SubfieldBase + field_metaclass = SubfieldBase + +field_class = functools.partial(six.with_metaclass, field_metaclass) + +if getattr(settings, setting_name('POSTGRES_JSONFIELD'), False): + from django.contrib.postgres.fields import JSONField as JSONFieldBase +else: + JSONFieldBase = field_class(models.TextField) + + +class JSONField(JSONFieldBase): + """Simple JSON field that stores python structures as JSON strings + on database. + """ + + def __init__(self, *args, **kwargs): + kwargs.setdefault('default', dict) + super(JSONField, self).__init__(*args, **kwargs) + + def from_db_value(self, value, expression, connection, context): + return self.to_python(value) + + def to_python(self, value): + """ + Convert the input JSON value into python structures, raises + django.core.exceptions.ValidationError if the data can't be converted. + """ + if self.blank and not value: + return {} + value = value or '{}' + if isinstance(value, six.binary_type): + value = six.text_type(value, 'utf-8') + if isinstance(value, six.string_types): + try: + # with django 1.6 i have '"{}"' as default value here + if value[0] == value[-1] == '"': + value = value[1:-1] + + return json.loads(value) + except Exception as err: + raise ValidationError(str(err)) + else: + return value + + def validate(self, value, model_instance): + """Check value is a valid JSON string, raise ValidationError on + error.""" + if isinstance(value, six.string_types): + super(JSONField, self).validate(value, model_instance) + try: + json.loads(value) + except Exception as err: + raise ValidationError(str(err)) + + def get_prep_value(self, value): + """Convert value to JSON string before save""" + try: + return json.dumps(value) + except Exception as err: + raise ValidationError(str(err)) + + def value_to_string(self, obj): + """Return value from object converted to string properly""" + return smart_text(self.value_from_object(obj)) + + def value_from_object(self, obj): + """Return value dumped to string.""" + orig_val = super(JSONField, self).value_from_object(obj) + return self.get_prep_value(orig_val) + diff --git a/thirdpart/social_django/management/__init__.py b/thirdpart/social_django/management/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/thirdpart/social_django/management/commands/__init__.py b/thirdpart/social_django/management/commands/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/thirdpart/social_django/management/commands/clearsocial.py b/thirdpart/social_django/management/commands/clearsocial.py new file mode 100644 index 0000000000..fa1533334d --- /dev/null +++ b/thirdpart/social_django/management/commands/clearsocial.py @@ -0,0 +1,35 @@ +from datetime import timedelta + +from django.core.management.base import BaseCommand +from django.utils import timezone + +from social_django.models import Code, Partial + + +class Command(BaseCommand): + help = 'removes old not used verification codes and partials' + + def add_arguments(self, parser): + super(Command, self).add_arguments(parser) + parser.add_argument( + '--age', + action='store', + type=int, + dest='age', + default=14, + help='how long to keep unused data (in days, defaults to 14)' + ) + + def handle(self, *args, **options): + age = timezone.now() - timedelta(days=options['age']) + + # Delete old not verified codes + Code.objects.filter( + verified=False, + timestamp__lt=age + ).delete() + + # Delete old partial data + Partial.objects.filter( + timestamp__lt=age + ).delete() diff --git a/thirdpart/social_django/managers.py b/thirdpart/social_django/managers.py new file mode 100644 index 0000000000..1fa91b68f7 --- /dev/null +++ b/thirdpart/social_django/managers.py @@ -0,0 +1,15 @@ +from django.db import models + + +class UserSocialAuthManager(models.Manager): + """Manager for the UserSocialAuth django model.""" + + class Meta: + app_label = "social_django" + + def get_social_auth(self, provider, uid): + try: + return self.select_related('user').get(provider=provider, + uid=uid) + except self.model.DoesNotExist: + return None diff --git a/thirdpart/social_django/middleware.py b/thirdpart/social_django/middleware.py new file mode 100644 index 0000000000..fd0adc4be5 --- /dev/null +++ b/thirdpart/social_django/middleware.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +import six + +from django.apps import apps +from django.conf import settings +from django.contrib import messages +from django.contrib.messages.api import MessageFailure +from django.core.urlresolvers import reverse +from django.shortcuts import redirect +from django.utils.http import urlquote + +from social_core.exceptions import SocialAuthBaseException +from social_core.utils import social_logger +from .compat import MiddlewareMixin + + +class SocialAuthExceptionMiddleware(MiddlewareMixin): + """Middleware that handles Social Auth AuthExceptions by providing the user + with a message, logging an error, and redirecting to some next location. + + By default, the exception message itself is sent to the user and they are + redirected to the location specified in the SOCIAL_AUTH_LOGIN_ERROR_URL + setting. + + This middleware can be extended by overriding the get_message or + get_redirect_uri methods, which each accept request and exception. + """ + def process_exception(self, request, exception): + strategy = getattr(request, 'social_strategy', None) + if strategy is None or self.raise_exception(request, exception): + return + + if isinstance(exception, SocialAuthBaseException): + backend = getattr(request, 'backend', None) + backend_name = getattr(backend, 'name', 'unknown-backend') + + message = self.get_message(request, exception) + url = self.get_redirect_uri(request, exception) + + if apps.is_installed('django.contrib.messages'): + social_logger.info(message) + try: + messages.error(request, message, + extra_tags='social-auth ' + backend_name) + except MessageFailure: + if url: + url += ('?' in url and '&' or '?') + \ + 'message={0}&backend={1}'.format(urlquote(message), + backend_name) + else: + social_logger.error(message) + + if url: + return redirect(url) + else: + return redirect(reverse('edit_profile')) + + def raise_exception(self, request, exception): + strategy = getattr(request, 'social_strategy', None) + if strategy is not None: + return strategy.setting('RAISE_EXCEPTIONS') or settings.DEBUG + + def get_message(self, request, exception): + return six.text_type(exception) + + def get_redirect_uri(self, request, exception): + strategy = getattr(request, 'social_strategy', None) + return strategy.setting('LOGIN_ERROR_URL') diff --git a/thirdpart/social_django/migrations/0001_initial.py b/thirdpart/social_django/migrations/0001_initial.py new file mode 100644 index 0000000000..61315ec33d --- /dev/null +++ b/thirdpart/social_django/migrations/0001_initial.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.11 on 2018-05-25 03:27 +from __future__ import unicode_literals + +from django.db import migrations, models +import seahub.base.fields +import social_django.fields +import social_django.storage + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='Association', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('server_url', models.CharField(max_length=255)), + ('handle', models.CharField(max_length=255)), + ('secret', models.CharField(max_length=255)), + ('issued', models.IntegerField()), + ('lifetime', models.IntegerField()), + ('assoc_type', models.CharField(max_length=64)), + ], + options={ + 'db_table': 'social_auth_association', + }, + bases=(models.Model, social_django.storage.DjangoAssociationMixin), + ), + migrations.CreateModel( + name='Code', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('email', models.EmailField(max_length=254)), + ('code', models.CharField(db_index=True, max_length=32)), + ('verified', models.BooleanField(default=False)), + ('timestamp', models.DateTimeField(auto_now_add=True, db_index=True)), + ], + options={ + 'db_table': 'social_auth_code', + }, + bases=(models.Model, social_django.storage.DjangoCodeMixin), + ), + migrations.CreateModel( + name='Nonce', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('server_url', models.CharField(max_length=255)), + ('timestamp', models.IntegerField()), + ('salt', models.CharField(max_length=65)), + ], + options={ + 'db_table': 'social_auth_nonce', + }, + bases=(models.Model, social_django.storage.DjangoNonceMixin), + ), + migrations.CreateModel( + name='Partial', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('token', models.CharField(db_index=True, max_length=32)), + ('next_step', models.PositiveSmallIntegerField(default=0)), + ('backend', models.CharField(max_length=32)), + ('data', social_django.fields.JSONField(default=dict)), + ('timestamp', models.DateTimeField(auto_now_add=True, db_index=True)), + ], + options={ + 'db_table': 'social_auth_partial', + }, + bases=(models.Model, social_django.storage.DjangoPartialMixin), + ), + migrations.CreateModel( + name='UserSocialAuth', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('username', seahub.base.fields.LowerCaseCharField(db_index=True, max_length=255)), + ('provider', models.CharField(max_length=32)), + ('uid', models.CharField(max_length=255)), + ('extra_data', social_django.fields.JSONField(default=dict)), + ], + options={ + 'db_table': 'social_auth_usersocialauth', + }, + bases=(models.Model, social_django.storage.DjangoUserMixin), + ), + migrations.AlterUniqueTogether( + name='usersocialauth', + unique_together=set([('provider', 'uid')]), + ), + migrations.AlterUniqueTogether( + name='nonce', + unique_together=set([('server_url', 'timestamp', 'salt')]), + ), + migrations.AlterUniqueTogether( + name='code', + unique_together=set([('email', 'code')]), + ), + migrations.AlterUniqueTogether( + name='association', + unique_together=set([('server_url', 'handle')]), + ), + ] diff --git a/thirdpart/social_django/migrations/0002_auto_20181115_0825.py b/thirdpart/social_django/migrations/0002_auto_20181115_0825.py new file mode 100644 index 0000000000..48d76a3d93 --- /dev/null +++ b/thirdpart/social_django/migrations/0002_auto_20181115_0825.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.15 on 2018-11-15 08:25 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('social_django', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='usersocialauth', + name='uid', + field=models.CharField(max_length=150), + ), + ] diff --git a/thirdpart/social_django/migrations/__init__.py b/thirdpart/social_django/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/thirdpart/social_django/models.py b/thirdpart/social_django/models.py new file mode 100644 index 0000000000..15f11e6a6b --- /dev/null +++ b/thirdpart/social_django/models.py @@ -0,0 +1,150 @@ +"""Django ORM models for Social Auth""" +import six + +from django.db import models +from django.conf import settings +from django.db.utils import IntegrityError + +from social_core.utils import setting_name +from seahub.base.accounts import User +from seahub.base.fields import LowerCaseCharField + +from .compat import get_rel_model +from .storage import DjangoUserMixin, DjangoAssociationMixin, \ + DjangoNonceMixin, DjangoCodeMixin, \ + DjangoPartialMixin, BaseDjangoStorage +from .fields import JSONField +from .managers import UserSocialAuthManager + +USER_MODEL = getattr(settings, setting_name('USER_MODEL'), None) or \ + getattr(settings, 'AUTH_USER_MODEL', None) or \ + 'auth.User' +UID_LENGTH = getattr(settings, setting_name('UID_LENGTH'), 150) +EMAIL_LENGTH = getattr(settings, setting_name('EMAIL_LENGTH'), 254) +NONCE_SERVER_URL_LENGTH = getattr( + settings, setting_name('NONCE_SERVER_URL_LENGTH'), 255) +ASSOCIATION_SERVER_URL_LENGTH = getattr( + settings, setting_name('ASSOCIATION_SERVER_URL_LENGTH'), 255) +ASSOCIATION_HANDLE_LENGTH = getattr( + settings, setting_name('ASSOCIATION_HANDLE_LENGTH'), 255) + + +class AbstractUserSocialAuth(models.Model, DjangoUserMixin): + """Abstract Social Auth association model""" + # user = models.ForeignKey(USER_MODEL, related_name='social_auth', + # on_delete=models.CASCADE) + username = LowerCaseCharField(max_length=255, db_index=True) + provider = models.CharField(max_length=32) + uid = models.CharField(max_length=UID_LENGTH) + extra_data = JSONField() + objects = UserSocialAuthManager() + + def __str__(self): + return str(self.username) + + class Meta: + app_label = "social_django" + abstract = True + + @classmethod + def get_social_auth(cls, provider, uid): + try: + social_auth = cls.objects.get(provider=provider, uid=uid) + except cls.DoesNotExist: + return None + + try: + u = User.objects.get(email=social_auth.username) + social_auth.user = u + except User.DoesNotExist: + social_auth.user = None + + return social_auth + + @classmethod + def username_max_length(cls): + return 255 + # username_field = cls.username_field() + # field = cls.user_model()._meta.get_field(username_field) + # return field.max_length + + @classmethod + def user_model(cls): + return User + # user_model = get_rel_model(field=cls._meta.get_field('user')) + # return user_model + + +class UserSocialAuth(AbstractUserSocialAuth): + """Social Auth association model""" + + class Meta: + """Meta data""" + app_label = "social_django" + unique_together = ('provider', 'uid') + db_table = 'social_auth_usersocialauth' + + +class Nonce(models.Model, DjangoNonceMixin): + """One use numbers""" + server_url = models.CharField(max_length=NONCE_SERVER_URL_LENGTH) + timestamp = models.IntegerField() + salt = models.CharField(max_length=65) + + class Meta: + app_label = "social_django" + unique_together = ('server_url', 'timestamp', 'salt') + db_table = 'social_auth_nonce' + + +class Association(models.Model, DjangoAssociationMixin): + """OpenId account association""" + server_url = models.CharField(max_length=ASSOCIATION_SERVER_URL_LENGTH) + handle = models.CharField(max_length=ASSOCIATION_HANDLE_LENGTH) + secret = models.CharField(max_length=255) # Stored base64 encoded + issued = models.IntegerField() + lifetime = models.IntegerField() + assoc_type = models.CharField(max_length=64) + + class Meta: + app_label = "social_django" + db_table = 'social_auth_association' + unique_together = ( + ('server_url', 'handle',) + ) + + +class Code(models.Model, DjangoCodeMixin): + email = models.EmailField(max_length=EMAIL_LENGTH) + code = models.CharField(max_length=32, db_index=True) + verified = models.BooleanField(default=False) + timestamp = models.DateTimeField(auto_now_add=True, db_index=True) + + class Meta: + app_label = "social_django" + db_table = 'social_auth_code' + unique_together = ('email', 'code') + + +class Partial(models.Model, DjangoPartialMixin): + token = models.CharField(max_length=32, db_index=True) + next_step = models.PositiveSmallIntegerField(default=0) + backend = models.CharField(max_length=32) + data = JSONField() + timestamp = models.DateTimeField(auto_now_add=True, db_index=True) + + class Meta: + app_label = "social_django" + db_table = 'social_auth_partial' + + +class DjangoStorage(BaseDjangoStorage): + user = UserSocialAuth + nonce = Nonce + association = Association + code = Code + partial = Partial + + @classmethod + def is_integrity_error(cls, exception): + return exception.__class__ is IntegrityError diff --git a/thirdpart/social_django/storage.py b/thirdpart/social_django/storage.py new file mode 100644 index 0000000000..60d68fe5df --- /dev/null +++ b/thirdpart/social_django/storage.py @@ -0,0 +1,220 @@ +"""Django ORM models for Social Auth""" +import base64 +import six +import sys +from django.db import transaction +from django.db.utils import IntegrityError + +from social_core.storage import UserMixin, AssociationMixin, NonceMixin, \ + CodeMixin, PartialMixin, BaseStorage +from seahub.base.accounts import User + + +class DjangoUserMixin(UserMixin): + """Social Auth association model""" + @classmethod + def changed(cls, user): + user.save() + + def set_extra_data(self, extra_data=None): + if super(DjangoUserMixin, self).set_extra_data(extra_data): + self.save() + + @classmethod + def allowed_to_disconnect(cls, user, backend_name, association_id=None): + if association_id is not None: + qs = cls.objects.exclude(id=association_id) + else: + qs = cls.objects.exclude(provider=backend_name) + qs = qs.filter(username=user.username) + + if hasattr(user, 'has_usable_password'): + valid_password = user.has_usable_password() + else: + valid_password = True + return valid_password or qs.count() > 0 + + @classmethod + def disconnect(cls, entry): + entry.delete() + + @classmethod + def username_field(cls): + return 'username' + # return getattr(cls.user_model(), 'USERNAME_FIELD', 'username') + + @classmethod + def user_exists(cls, *args, **kwargs): + """ + Return True/False if a User instance exists with the given arguments. + Arguments are directly passed to filter() manager method. + """ + if 'username' in kwargs: + kwargs[cls.username_field()] = kwargs.pop('username') + + assert 'username' in kwargs + + try: + User.objects.get(email=kwargs['username']) + return True + except User.DoesNotExist: + return False + # return cls.user_model().objects.filter(*args, **kwargs).count() > 0 + + @classmethod + def get_username(cls, user): + return getattr(user, cls.username_field(), None) + + @classmethod + def create_user(cls, *args, **kwargs): + username_field = cls.username_field() + if 'username' in kwargs and username_field not in kwargs: + kwargs[username_field] = kwargs.pop('username') + + assert 'username' in kwargs + + user = User.objects.create_user(email=kwargs['username'], + is_active=True, + save_profile=False) + + # try: + # if hasattr(transaction, 'atomic'): + # # In Django versions that have an "atomic" transaction decorator / context + # # manager, there's a transaction wrapped around this call. + # # If the create fails below due to an IntegrityError, ensure that the transaction + # # stays undamaged by wrapping the create in an atomic. + # with transaction.atomic(): + # user = cls.user_model().objects.create_user(*args, **kwargs) + # else: + # user = cls.user_model().objects.create_user(*args, **kwargs) + # except IntegrityError: + # # User might have been created on a different thread, try and find them. + # # If we don't, re-raise the IntegrityError. + # exc_info = sys.exc_info() + # # If email comes in as None it won't get found in the get + # if kwargs.get('email', True) is None: + # kwargs['email'] = '' + # try: + # user = cls.user_model().objects.get(*args, **kwargs) + # except cls.user_model().DoesNotExist: + # six.reraise(*exc_info) + return user + + @classmethod + def get_user(cls, pk=None, **kwargs): + if pk: + kwargs = {'pk': pk} + + try: + return User.objects.get(email=pk) + except User.DoesNotExist: + return None + # try: + # return cls.user_model().objects.get(**kwargs) + # except cls.user_model().DoesNotExist: + # return None + + @classmethod + def get_users_by_email(cls, email): + user_model = cls.user_model() + email_field = getattr(user_model, 'EMAIL_FIELD', 'email') + return user_model.objects.filter(**{email_field + '__iexact': email}) + + @classmethod + def get_social_auth(cls, provider, uid): + if not isinstance(uid, six.string_types): + uid = str(uid) + try: + return cls.objects.get(provider=provider, uid=uid) + except cls.DoesNotExist: + return None + + @classmethod + def get_social_auth_for_user(cls, user, provider=None, id=None): + qs = cls.objects.filter(username=user.username) + + if provider: + qs = qs.filter(provider=provider) + + if id: + qs = qs.filter(id=id) + return qs + + @classmethod + def create_social_auth(cls, user, uid, provider): + if not isinstance(uid, six.string_types): + uid = str(uid) + if hasattr(transaction, 'atomic'): + # In Django versions that have an "atomic" transaction decorator / context + # manager, there's a transaction wrapped around this call. + # If the create fails below due to an IntegrityError, ensure that the transaction + # stays undamaged by wrapping the create in an atomic. + with transaction.atomic(): + social_auth = cls.objects.create(username=user.username, uid=uid, provider=provider) + else: + social_auth = cls.objects.create(username=user.username, uid=uid, provider=provider) + return social_auth + + +class DjangoNonceMixin(NonceMixin): + @classmethod + def use(cls, server_url, timestamp, salt): + return cls.objects.get_or_create(server_url=server_url, + timestamp=timestamp, + salt=salt)[1] + + +class DjangoAssociationMixin(AssociationMixin): + @classmethod + def store(cls, server_url, association): + # Don't use get_or_create because issued cannot be null + try: + assoc = cls.objects.get(server_url=server_url, + handle=association.handle) + except cls.DoesNotExist: + assoc = cls(server_url=server_url, + handle=association.handle) + assoc.secret = base64.encodestring(association.secret) + assoc.issued = association.issued + assoc.lifetime = association.lifetime + assoc.assoc_type = association.assoc_type + assoc.save() + + @classmethod + def get(cls, *args, **kwargs): + return cls.objects.filter(*args, **kwargs) + + @classmethod + def remove(cls, ids_to_delete): + cls.objects.filter(pk__in=ids_to_delete).delete() + + +class DjangoCodeMixin(CodeMixin): + @classmethod + def get_code(cls, code): + try: + return cls.objects.get(code=code) + except cls.DoesNotExist: + return None + + +class DjangoPartialMixin(PartialMixin): + @classmethod + def load(cls, token): + try: + return cls.objects.get(token=token) + except cls.DoesNotExist: + return None + + @classmethod + def destroy(cls, token): + partial = cls.load(token) + if partial: + partial.delete() + + +class BaseDjangoStorage(BaseStorage): + user = DjangoUserMixin + nonce = DjangoNonceMixin + association = DjangoAssociationMixin + code = DjangoCodeMixin diff --git a/thirdpart/social_django/strategy.py b/thirdpart/social_django/strategy.py new file mode 100644 index 0000000000..1a3a820afa --- /dev/null +++ b/thirdpart/social_django/strategy.py @@ -0,0 +1,159 @@ +# coding=utf-8 +from django.conf import settings +from django.http import HttpResponse, HttpRequest +from django.db.models import Model +from django.contrib.contenttypes.models import ContentType +from django.contrib.auth import authenticate +from django.shortcuts import redirect, resolve_url +from django.template import TemplateDoesNotExist, loader, engines +from django.utils.crypto import get_random_string +from django.utils.encoding import force_text +from django.utils.functional import Promise +from django.utils.translation import get_language + +from social_core.strategy import BaseStrategy, BaseTemplateStrategy +from .compat import get_request_port + + +def render_template_string(request, html, context=None): + """Take a template in the form of a string and render it for the + given context""" + template = engines['django'].from_string(html) + return template.render(context=context, request=request) + + +class DjangoTemplateStrategy(BaseTemplateStrategy): + def render_template(self, tpl, context): + template = loader.get_template(tpl) + return template.render(context=context, request=self.strategy.request) + + def render_string(self, html, context): + return render_template_string(self.strategy.request, html, context) + + +class DjangoStrategy(BaseStrategy): + DEFAULT_TEMPLATE_STRATEGY = DjangoTemplateStrategy + + def __init__(self, storage, request=None, tpl=None): + self.request = request + self.session = request.session if request else {} + super(DjangoStrategy, self).__init__(storage, tpl) + + def get_setting(self, name): + value = getattr(settings, name) + # Force text on URL named settings that are instance of Promise + if name.endswith('_URL'): + if isinstance(value, Promise): + value = force_text(value) + value = resolve_url(value) + return value + + def request_data(self, merge=True): + if not self.request: + return {} + if merge: + data = self.request.GET.copy() + data.update(self.request.POST) + elif self.request.method == 'POST': + data = self.request.POST + else: + data = self.request.GET + return data + + def request_host(self): + if self.request: + return self.request.get_host() + + def request_is_secure(self): + """Is the request using HTTPS?""" + return self.request.is_secure() + + def request_path(self): + """path of the current request""" + return self.request.path + + def request_port(self): + """Port in use for this request""" + return get_request_port(request=self.request) + + def request_get(self): + """Request GET data""" + return self.request.GET.copy() + + def request_post(self): + """Request POST data""" + return self.request.POST.copy() + + def redirect(self, url): + return redirect(url) + + def html(self, content): + return HttpResponse(content, content_type='text/html;charset=UTF-8') + + def render_html(self, tpl=None, html=None, context=None): + if not tpl and not html: + raise ValueError('Missing template or html parameters') + context = context or {} + try: + template = loader.get_template(tpl) + return template.render(context=context, request=self.request) + except TemplateDoesNotExist: + return render_template_string(self.request, html, context) + + def authenticate(self, backend, *args, **kwargs): + kwargs['strategy'] = self + kwargs['storage'] = self.storage + kwargs['backend'] = backend + return authenticate(*args, **kwargs) + + def clean_authenticate_args(self, *args, **kwargs): + """Cleanup request argument if present, which is passed to + authenticate as for Django 1.11""" + if len(args) > 0 and isinstance(args[0], HttpRequest): + kwargs['request'], args = args[0], args[1:] + return args, kwargs + + def session_get(self, name, default=None): + return self.session.get(name, default) + + def session_set(self, name, value): + self.session[name] = value + if hasattr(self.session, 'modified'): + self.session.modified = True + + def session_pop(self, name): + return self.session.pop(name, None) + + def session_setdefault(self, name, value): + return self.session.setdefault(name, value) + + def build_absolute_uri(self, path=None): + if self.request: + return self.request.build_absolute_uri(path) + else: + return path + + def random_string(self, length=12, chars=BaseStrategy.ALLOWED_CHARS): + return get_random_string(length, chars) + + def to_session_value(self, val): + """Converts values that are instance of Model to a dictionary + with enough information to retrieve the instance back later.""" + if isinstance(val, Model): + val = { + 'pk': val.pk, + 'ctype': ContentType.objects.get_for_model(val).pk + } + return val + + def from_session_value(self, val): + """Converts back the instance saved by self._ctype function.""" + if isinstance(val, dict) and 'pk' in val and 'ctype' in val: + ctype = ContentType.objects.get_for_id(val['ctype']) + ModelClass = ctype.model_class() + val = ModelClass.objects.get(pk=val['pk']) + return val + + def get_language(self): + """Return current language""" + return get_language() diff --git a/thirdpart/social_django/urls.py b/thirdpart/social_django/urls.py new file mode 100644 index 0000000000..68a07c7273 --- /dev/null +++ b/thirdpart/social_django/urls.py @@ -0,0 +1,24 @@ +"""URLs module""" +from django.conf import settings +from django.conf.urls import url + +from social_core.utils import setting_name +from . import views + + +extra = getattr(settings, setting_name('TRAILING_SLASH'), True) and '/' or '' + +app_name = 'social' + +urlpatterns = [ + # authentication / association + url(r'^login/(?P[^/]+){0}$'.format(extra), views.auth, + name='begin'), + url(r'^complete/(?P[^/]+){0}$'.format(extra), views.complete, + name='complete'), + # disconnection + url(r'^disconnect/(?P[^/]+){0}$'.format(extra), views.disconnect, + name='disconnect'), + url(r'^disconnect/(?P[^/]+)/(?P\d+){0}$' + .format(extra), views.disconnect, name='disconnect_individual'), +] diff --git a/thirdpart/social_django/utils.py b/thirdpart/social_django/utils.py new file mode 100644 index 0000000000..281bdd49f6 --- /dev/null +++ b/thirdpart/social_django/utils.py @@ -0,0 +1,51 @@ +# coding=utf-8 +from functools import wraps + +from django.conf import settings +from django.http import Http404 + +from social_core.utils import setting_name, module_member, get_strategy +from social_core.exceptions import MissingBackend +from social_core.backends.utils import get_backend +from .compat import reverse + + +BACKENDS = settings.AUTHENTICATION_BACKENDS +STRATEGY = getattr(settings, setting_name('STRATEGY'), + 'social_django.strategy.DjangoStrategy') +STORAGE = getattr(settings, setting_name('STORAGE'), + 'social_django.models.DjangoStorage') +Strategy = module_member(STRATEGY) +Storage = module_member(STORAGE) + + +def load_strategy(request=None): + return get_strategy(STRATEGY, STORAGE, request) + + +def load_backend(strategy, name, redirect_uri): + Backend = get_backend(BACKENDS, name) + return Backend(strategy, redirect_uri) + + +def psa(redirect_uri=None, load_strategy=load_strategy): + def decorator(func): + @wraps(func) + def wrapper(request, backend, *args, **kwargs): + uri = redirect_uri + if uri and not uri.startswith('/'): + uri = reverse(redirect_uri, args=(backend,)) + request.social_strategy = load_strategy(request) + # backward compatibility in attribute name, only if not already + # defined + if not hasattr(request, 'strategy'): + request.strategy = request.social_strategy + + try: + request.backend = load_backend(request.social_strategy, + backend, uri) + except MissingBackend: + raise Http404('Backend not found') + return func(request, backend, *args, **kwargs) + return wrapper + return decorator diff --git a/thirdpart/social_django/views.py b/thirdpart/social_django/views.py new file mode 100644 index 0000000000..a43008ef53 --- /dev/null +++ b/thirdpart/social_django/views.py @@ -0,0 +1,131 @@ +from django.conf import settings +from django.contrib.auth import REDIRECT_FIELD_NAME +from django.contrib.auth.decorators import login_required +from django.views.decorators.csrf import csrf_exempt, csrf_protect +from django.views.decorators.http import require_POST +from django.views.decorators.cache import never_cache + +from seahub.auth import login + +from social_core.utils import setting_name +from social_core.actions import do_auth, do_complete, do_disconnect +from .utils import psa + + +NAMESPACE = getattr(settings, setting_name('URL_NAMESPACE'), None) or 'social' + +# Calling `session.set_expiry(None)` results in a session lifetime equal to +# platform default session lifetime. +DEFAULT_SESSION_TIMEOUT = None + + +@never_cache +@psa('{0}:complete'.format(NAMESPACE)) +def auth(request, backend): + return do_auth(request.backend, redirect_name=REDIRECT_FIELD_NAME) + + +@never_cache +@csrf_exempt +@psa('{0}:complete'.format(NAMESPACE)) +def complete(request, backend, *args, **kwargs): + """Authentication complete view""" + return do_complete(request.backend, _do_login, request.user, + redirect_name=REDIRECT_FIELD_NAME, request=request, + *args, **kwargs) + + +@never_cache +@login_required +@psa() +@require_POST +@csrf_protect +def disconnect(request, backend, association_id=None): + """Disconnects given backend from current logged in user.""" + return do_disconnect(request.backend, request.user, association_id, + redirect_name=REDIRECT_FIELD_NAME) + + +def get_session_timeout(social_user, enable_session_expiration=False, + max_session_length=None): + if enable_session_expiration: + # Retrieve an expiration date from the social user who just finished + # logging in; this value was set by the social auth backend, and was + # typically received from the server. + expiration = social_user.expiration_datetime() + + # We've enabled session expiration. Check to see if we got + # a specific expiration time from the provider for this user; + # if not, use the platform default expiration. + if expiration: + received_expiration_time = expiration.total_seconds() + else: + received_expiration_time = DEFAULT_SESSION_TIMEOUT + + # Check to see if the backend set a value as a maximum length + # that a session may be; if they did, then we should use the minimum + # of that and the received session expiration time, if any, to + # set the session length. + if received_expiration_time is None and max_session_length is None: + # We neither received an expiration length, nor have a maximum + # session length. Use the platform default. + session_expiry = DEFAULT_SESSION_TIMEOUT + elif received_expiration_time is None and max_session_length is not None: + # We only have a maximum session length; use that. + session_expiry = max_session_length + elif received_expiration_time is not None and max_session_length is None: + # We only have an expiration time received by the backend + # from the provider, with no set maximum. Use that. + session_expiry = received_expiration_time + else: + # We received an expiration time from the backend, and we also + # have a set maximum session length. Use the smaller of the two. + session_expiry = min(received_expiration_time, max_session_length) + else: + # If there's an explicitly-set maximum session length, use that + # even if we don't want to retrieve session expiry times from + # the backend. If there isn't, then use the platform default. + if max_session_length is None: + session_expiry = DEFAULT_SESSION_TIMEOUT + else: + session_expiry = max_session_length + + return session_expiry + + +def _do_login(backend, user, social_user): + user.backend = '{0}.{1}'.format(backend.__module__, + backend.__class__.__name__) + # Get these details early to avoid any issues involved in the + # session switch that happens when we call login(). + enable_session_expiration = backend.setting('SESSION_EXPIRATION', False) + max_session_length_setting = backend.setting('MAX_SESSION_LENGTH', None) + + # Log the user in, creating a new session. + login(backend.strategy.request, user) + + # Make sure that the max_session_length value is either an integer or + # None. Because we get this as a setting from the backend, it can be set + # to whatever the backend creator wants; we want to be resilient against + # unexpected types being presented to us. + try: + max_session_length = int(max_session_length_setting) + except (TypeError, ValueError): + # We got a response that doesn't look like a number; use the default. + max_session_length = None + + # Get the session expiration length based on the maximum session length + # setting, combined with any session length received from the backend. + session_expiry = get_session_timeout( + social_user, + enable_session_expiration=enable_session_expiration, + max_session_length=max_session_length, + ) + + try: + # Set the session length to our previously determined expiry length. + backend.strategy.request.session.set_expiry(session_expiry) + except OverflowError: + # The timestamp we used wasn't in the range of values supported by + # Django for session length; use the platform default. We tried. + backend.strategy.request.session.set_expiry(DEFAULT_SESSION_TIMEOUT) diff --git a/thirdpart/weworkapi/AbstractApi.py b/thirdpart/weworkapi/AbstractApi.py new file mode 100644 index 0000000000..7dde5ee81b --- /dev/null +++ b/thirdpart/weworkapi/AbstractApi.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +## + # Copyright (C) 2018 All rights reserved. + # + # @File AbstractApi.py + # @Brief + # @Author abelzhu, abelzhu@tencent.com + # @Version 1.0 + # @Date 2018-02-24 + # + # + +import sys +import os +import re + +import json +import requests + +DEBUG = False + +class ApiException(Exception) : + def __init__(self, errCode, errMsg) : + self.errCode = errCode + self.errMsg = errMsg + +class AbstractApi(object) : + def __init__(self) : + return + + def getAccessToken(self) : + raise NotImplementedError + def refreshAccessToken(self) : + raise NotImplementedError + + def getSuiteAccessToken(self) : + raise NotImplementedError + def refreshSuiteAccessToken(self) : + raise NotImplementedError + + def getProviderAccessToken(self) : + raise NotImplementedError + def refreshProviderAccessToken(self) : + raise NotImplementedError + + def httpCall(self, urlType, args=None) : + shortUrl = urlType[0] + method = urlType[1] + response = {} + for retryCnt in range(0, 3) : + if 'POST' == method : + url = self.__makeUrl(shortUrl) + response = self.__httpPost(url, args) + elif 'GET' == method : + url = self.__makeUrl(shortUrl) + url = self.__appendArgs(url, args) + response = self.__httpGet(url) + else : + raise ApiException(-1, "unknown method type") + + # check if token expired + if self.__tokenExpired(response.get('errcode')) : + self.__refreshToken(shortUrl) + retryCnt += 1 + continue + else : + break + + return self.__checkResponse(response) + + @staticmethod + def __appendArgs(url, args) : + if args is None : + return url + + for key, value in args.items() : + if '?' in url : + url += ('&' + key + '=' + value) + else : + url += ('?' + key + '=' + value) + return url + + @staticmethod + def __makeUrl(shortUrl) : + base = "https://qyapi.weixin.qq.com" + if shortUrl[0] == '/' : + return base + shortUrl + else : + return base + '/' + shortUrl + + def __appendToken(self, url) : + if 'SUITE_ACCESS_TOKEN' in url : + return url.replace('SUITE_ACCESS_TOKEN', self.getSuiteAccessToken()) + elif 'PROVIDER_ACCESS_TOKEN' in url : + return url.replace('PROVIDER_ACCESS_TOKEN', self.getProviderAccessToken()) + elif 'ACCESS_TOKEN' in url : + return url.replace('ACCESS_TOKEN', self.getAccessToken()) + else : + return url + + def __httpPost(self, url, args) : + realUrl = self.__appendToken(url) + + if DEBUG is True : + print realUrl, args + + return requests.post(realUrl, data = json.dumps(args, ensure_ascii = False).encode('utf-8')).json() + + def __httpGet(self, url) : + realUrl = self.__appendToken(url) + + if DEBUG is True : + print realUrl + + return requests.get(realUrl).json() + + def __post_file(self, url, media_file): + return requests.post(url, file=media_file).json() + + @staticmethod + def __checkResponse(response): + errCode = response.get('errcode') + errMsg = response.get('errmsg') + + if errCode is 0: + return response + else: + raise ApiException(errCode, errMsg) + + @staticmethod + def __tokenExpired(errCode) : + if errCode == 40014 or errCode == 42001 or errCode == 42007 or errCode == 42009 : + return True + else : + return False + + def __refreshToken(self, url) : + if 'SUITE_ACCESS_TOKEN' in url : + self.refreshSuiteAccessToken() + elif 'PROVIDER_ACCESS_TOKEN' in url : + self.refreshProviderAccessToken() + elif 'ACCESS_TOKEN' in url : + self.refreshAccessToken() diff --git a/thirdpart/weworkapi/CorpApi.py b/thirdpart/weworkapi/CorpApi.py new file mode 100644 index 0000000000..71ef46ae41 --- /dev/null +++ b/thirdpart/weworkapi/CorpApi.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +## + # Copyright (C) 2018 All rights reserved. + # + # @File CorpApi.py + # @Brief + # @Author abelzhu, abelzhu@tencent.com + # @Version 1.0 + # @Date 2018-02-24 + # + # + +from .AbstractApi import * + +CORP_API_TYPE = { + 'GET_ACCESS_TOKEN' : ['/cgi-bin/gettoken', 'GET'], + 'USER_CREATE' : ['/cgi-bin/user/create?access_token=ACCESS_TOKEN', 'POST'], + 'USER_GET' : ['/cgi-bin/user/get?access_token=ACCESS_TOKEN', 'GET'], + 'USER_UPDATE' : ['/cgi-bin/user/update?access_token=ACCESS_TOKEN', 'POST'], + 'USER_DELETE' : ['/cgi-bin/user/delete?access_token=ACCESS_TOKEN', 'GET'], + 'USER_BATCH_DELETE': ['/cgi-bin/user/batchdelete?access_token=ACCESS_TOKEN', 'POST'], + 'USER_SIMPLE_LIST ': ['/cgi-bin/user/simplelist?access_token=ACCESS_TOKEN', 'GET'], + 'USER_LIST' : ['/cgi-bin/user/list?access_token=ACCESS_TOKEN', 'GET'], + 'USERID_TO_OPENID' : ['/cgi-bin/user/convert_to_openid?access_token=ACCESS_TOKEN', 'POST'], + 'OPENID_TO_USERID' : ['/cgi-bin/user/convert_to_userid?access_token=ACCESS_TOKEN', 'POST'], + 'USER_AUTH_SUCCESS': ['/cgi-bin/user/authsucc?access_token=ACCESS_TOKEN', 'GET'], + + 'DEPARTMENT_CREATE': ['/cgi-bin/department/create?access_token=ACCESS_TOKEN', 'POST'], + 'DEPARTMENT_UPDATE': ['/cgi-bin/department/update?access_token=ACCESS_TOKEN', 'POST'], + 'DEPARTMENT_DELETE': ['/cgi-bin/department/delete?access_token=ACCESS_TOKEN', 'GET'], + 'DEPARTMENT_LIST' : ['/cgi-bin/department/list?access_token=ACCESS_TOKEN', 'GET'], + + 'TAG_CREATE' : ['/cgi-bin/tag/create?access_token=ACCESS_TOKEN', 'POST'], + 'TAG_UPDATE' : ['/cgi-bin/tag/update?access_token=ACCESS_TOKEN', 'POST'], + 'TAG_DELETE' : ['/cgi-bin/tag/delete?access_token=ACCESS_TOKEN', 'GET'], + 'TAG_GET_USER' : ['/cgi-bin/tag/get?access_token=ACCESS_TOKEN', 'GET'], + 'TAG_ADD_USER' : ['/cgi-bin/tag/addtagusers?access_token=ACCESS_TOKEN', 'POST'], + 'TAG_DELETE_USER' : ['/cgi-bin/tag/deltagusers?access_token=ACCESS_TOKEN', 'POST'], + 'TAG_GET_LIST' : ['/cgi-bin/tag/list?access_token=ACCESS_TOKEN', 'GET'], + + 'BATCH_JOB_GET_RESULT' : ['/cgi-bin/batch/getresult?access_token=ACCESS_TOKEN', 'GET'], + + 'BATCH_INVITE' : ['/cgi-bin/batch/invite?access_token=ACCESS_TOKEN', 'POST'], + + 'AGENT_GET' : ['/cgi-bin/agent/get?access_token=ACCESS_TOKEN', 'GET'], + 'AGENT_SET' : ['/cgi-bin/agent/set?access_token=ACCESS_TOKEN', 'POST'], + 'AGENT_GET_LIST' : ['/cgi-bin/agent/list?access_token=ACCESS_TOKEN', 'GET'], + + 'MENU_CREATE' : ['/cgi-bin/menu/create?access_token=ACCESS_TOKEN', 'POST'], ## TODO + 'MENU_GET' : ['/cgi-bin/menu/get?access_token=ACCESS_TOKEN', 'GET'], + 'MENU_DELETE' : ['/cgi-bin/menu/delete?access_token=ACCESS_TOKEN', 'GET'], + + 'MESSAGE_SEND' : ['/cgi-bin/message/send?access_token=ACCESS_TOKEN', 'POST'], + 'MESSAGE_REVOKE' : ['/cgi-bin/message/revoke?access_token=ACCESS_TOKEN', 'POST'], + + 'MEDIA_GET' : ['/cgi-bin/media/get?access_token=ACCESS_TOKEN', 'GET'], + + 'GET_USER_INFO_BY_CODE' : ['/cgi-bin/user/getuserinfo?access_token=ACCESS_TOKEN', 'GET'], + 'GET_USER_DETAIL' : ['/cgi-bin/user/getuserdetail?access_token=ACCESS_TOKEN', 'POST'], + + 'GET_TICKET' : ['/cgi-bin/ticket/get?access_token=ACCESS_TOKEN', 'GET'], + 'GET_JSAPI_TICKET' : ['/cgi-bin/get_jsapi_ticket?access_token=ACCESS_TOKEN', 'GET'], + + 'GET_CHECKIN_OPTION' : ['/cgi-bin/checkin/getcheckinoption?access_token=ACCESS_TOKEN', 'POST'], + 'GET_CHECKIN_DATA' : ['/cgi-bin/checkin/getcheckindata?access_token=ACCESS_TOKEN', 'POST'], + 'GET_APPROVAL_DATA': ['/cgi-bin/corp/getapprovaldata?access_token=ACCESS_TOKEN', 'POST'], + + 'GET_INVOICE_INFO' : ['/cgi-bin/card/invoice/reimburse/getinvoiceinfo?access_token=ACCESS_TOKEN', 'POST'], + 'UPDATE_INVOICE_STATUS' : + ['/cgi-bin/card/invoice/reimburse/updateinvoicestatus?access_token=ACCESS_TOKEN', 'POST'], + 'BATCH_UPDATE_INVOICE_STATUS' : + ['/cgi-bin/card/invoice/reimburse/updatestatusbatch?access_token=ACCESS_TOKEN', 'POST'], + 'BATCH_GET_INVOICE_INFO' : + ['/cgi-bin/card/invoice/reimburse/getinvoiceinfobatch?access_token=ACCESS_TOKEN', 'POST'], + + 'APP_CHAT_CREATE' : ['/cgi-bin/appchat/create?access_token=ACCESS_TOKEN', 'POST'], + 'APP_CHAT_GET' : ['/cgi-bin/appchat/get?access_token=ACCESS_TOKEN', 'GET'], + 'APP_CHAT_UPDATE' : ['/cgi-bin/appchat/update?access_token=ACCESS_TOKEN', 'POST'], + 'APP_CHAT_SEND' : ['/cgi-bin/appchat/send?access_token=ACCESS_TOKEN', 'POST'], + + 'MINIPROGRAM_CODE_TO_SESSION_KEY' : ['/cgi-bin/miniprogram/jscode2session?access_token=ACCESS_TOKEN', 'GET'], +} + +class CorpApi(AbstractApi) : + def __init__(self, corpid, secret) : + self.corpid = corpid + self.secret = secret + self.access_token = None + + def getAccessToken(self) : + if self.access_token is None : + self.refreshAccessToken() + return self.access_token + + def refreshAccessToken(self) : + response = self.httpCall( + CORP_API_TYPE['GET_ACCESS_TOKEN'], + { + 'corpid' : self.corpid, + 'corpsecret': self.secret, + }) + self.access_token = response.get('access_token') + diff --git a/thirdpart/weworkapi/__init__.py b/thirdpart/weworkapi/__init__.py new file mode 100644 index 0000000000..009b3a2881 --- /dev/null +++ b/thirdpart/weworkapi/__init__.py @@ -0,0 +1 @@ +"""ref: https://github.com/sbzhu/weworkapi_python"""