1
0
mirror of https://github.com/haiwen/seahub.git synced 2025-09-01 15:09:14 +00:00

[social auth & notification] Add wechat work notification

This commit is contained in:
zhengxie
2018-11-14 16:04:50 +08:00
parent 57fac87e44
commit 0efcbb10a3
45 changed files with 2589 additions and 3 deletions

View File

@@ -20,3 +20,4 @@ gunicorn==19.8.1
django-webpack-loader==0.6.0
git+git://github.com/haiwen/python-cas.git@ffc49235fd7cc32c4fdda5acfa3707e1405881df#egg=python_cas
futures==3.2.0
social-auth-core==1.7.0

View File

@@ -0,0 +1,164 @@
# Copyright (c) 2012-2016 Seafile Ltd.
# encoding: utf-8
from datetime import datetime
import logging
import re
from django.conf import settings
from django.core.management.base import BaseCommand
from django.core.urlresolvers import reverse
from django.utils import translation
from django.utils.translation import ungettext
from social_django.models import UserSocialAuth
from weworkapi import CorpApi
from seahub.base.models import CommandsLastCheck
from seahub.notifications.models import UserNotification
from seahub.profile.models import Profile
from seahub.utils import get_site_scheme_and_netloc, get_site_name
# Get an instance of a logger
logger = logging.getLogger(__name__)
########## Utility Functions ##########
def wrap_div(s):
"""
Replace <a ..>xx</a> to xx and wrap content with <div></div>.
"""
patt = '<a.*?>(.+?)</a>'
def repl(matchobj):
return matchobj.group(1)
return '<div class="highlight">' + re.sub(patt, repl, s) + '</div>'
class CommandLogMixin(object):
def println(self, msg):
self.stdout.write('[%s] %s\n' % (str(datetime.now()), msg))
def log_error(self, msg):
logger.error(msg)
self.println(msg)
def log_info(self, msg):
logger.info(msg)
self.println(msg)
def log_debug(self, msg):
logger.debug(msg)
self.println(msg)
#######################################
class Command(BaseCommand, CommandLogMixin):
help = 'Send WeChat Work msg to user if he/she has unseen notices every '
'period of time.'
label = "notifications_send_wxwork_notices"
def handle(self, *args, **options):
self.log_debug('Start sending WeChat Work msg...')
self.api = CorpApi.CorpApi(settings.SOCIAL_AUTH_WEIXIN_WORK_KEY,
settings.SOCIAL_AUTH_WEIXIN_WORK_SECRET)
self.do_action()
self.log_debug('Finish sending WeChat Work msg.\n')
def send_wx_msg(self, uid, title, content, detail_url):
try:
self.log_info('Send wechat msg to user: %s, msg: %s' % (uid, content))
response = self.api.httpCall(
CorpApi.CORP_API_TYPE['MESSAGE_SEND'],
{
"touser": uid,
"agentid": settings.SOCIAL_AUTH_WEIXIN_WORK_AGENTID,
'msgtype': 'textcard',
# 'climsgid': 'climsgidclimsgid_d',
'textcard': {
'title': title,
'description': content,
'url': detail_url,
},
'safe': 0,
})
self.log_info(response)
except Exception as ex:
logger.error(ex, exc_info=True)
def get_user_language(self, username):
return Profile.objects.get_user_language(username)
def do_action(self):
now = datetime.now()
today = datetime.now().replace(hour=0).replace(minute=0).replace(
second=0).replace(microsecond=0)
# 1. get all users who are connected wechat work
socials = UserSocialAuth.objects.filter(provider='weixin-work')
users = [(x.username, x.uid) for x in socials]
if not users:
return
user_uid_map = {}
for username, uid in users:
user_uid_map[username] = uid
# 2. get previous time that command last runs
try:
cmd_last_check = CommandsLastCheck.objects.get(command_type=self.label)
self.log_debug('Last check time is %s' % cmd_last_check.last_check)
last_check_dt = cmd_last_check.last_check
cmd_last_check.last_check = now
cmd_last_check.save()
except CommandsLastCheck.DoesNotExist:
last_check_dt = today
self.log_debug('Create new last check time: %s' % now)
CommandsLastCheck(command_type=self.label, last_check=now).save()
# 3. get all unseen notices for those users
qs = UserNotification.objects.filter(
timestamp__gt=last_check_dt
).filter(seen=False).filter(
to_user__in=user_uid_map.keys()
)
user_notices = {}
for q in qs:
if q.to_user not in user_notices:
user_notices[q.to_user] = [q]
else:
user_notices[q.to_user].append(q)
# 4. send msg to users
url = get_site_scheme_and_netloc().rstrip('/') + reverse('user_notification_list')
for username, uid in users:
notices = user_notices.get(username, [])
count = len(notices)
if count == 0:
continue
# save current language
cur_language = translation.get_language()
# get and active user language
user_language = self.get_user_language(username)
translation.activate(user_language)
self.log_debug('Set language code to %s for user: %s' % (
user_language, username))
title = ungettext(
"\n"
"You've got 1 new notice on %(site_name)s:\n",
"\n"
"You've got %(num)s new notices on %(site_name)s:\n",
count
) % {
'num': count,
'site_name': get_site_name(),
}
content = ''.join([wrap_div(x.format_msg()) for x in notices])
self.send_wx_msg(uid, title, content, url)
translation.activate(cur_language)

View File

@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.15 on 2018-11-15 08:25
from __future__ import unicode_literals
import datetime
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('notifications', '0002_auto_20180426_0710'),
]
operations = [
migrations.AlterField(
model_name='usernotification',
name='timestamp',
field=models.DateTimeField(db_index=True, default=datetime.datetime.now),
),
]

View File

@@ -333,7 +333,7 @@ class UserNotification(models.Model):
to_user = LowerCaseCharField(db_index=True, max_length=255)
msg_type = models.CharField(db_index=True, max_length=30)
detail = models.TextField()
timestamp = models.DateTimeField(default=datetime.datetime.now)
timestamp = models.DateTimeField(db_index=True, default=datetime.datetime.now)
seen = models.BooleanField('seen', default=False)
objects = UserNotificationManager()
@@ -487,6 +487,32 @@ class UserNotification(models.Model):
return {'message': message, 'msg_from': msg_from}
########## functions used in templates
def format_msg(self):
if self.is_group_msg():
return self.format_group_message_title()
elif self.is_file_uploaded_msg():
return self.format_file_uploaded_msg()
elif self.is_repo_share_msg():
return self.format_repo_share_msg()
elif self.is_repo_share_to_group_msg():
return self.format_repo_share_to_group_msg()
elif self.is_group_join_request():
return self.format_group_join_request()
elif self.is_file_comment_msg():
return self.format_file_comment_msg()
elif self.is_review_comment_msg():
return self.format_review_comment_msg()
elif self.is_update_review_msg():
return self.format_update_review_msg()
elif self.is_request_reviewer_msg():
return self.format_request_reviewer_msg()
elif self.is_guest_invitation_accepted_msg():
return self.format_guest_invitation_accepted_msg()
elif self.is_add_user_to_group():
return self.format_add_user_to_group()
else:
return ''
def format_file_uploaded_msg(self):
"""

View File

@@ -175,6 +175,29 @@
</div>
{% endif %}
<div class="setting-item" id="social-auth">
<h3>{% trans "Social Login" %}</h3>
<ul>
<li>
{% if request.LANGUAGE_CODE == 'zh-cn' %}
企业微信
{% else %}
WeChat Work
{% endif %}
{% if social_connected %}
<a class="social-disconnect" href="#" data-url="{% url "social:disconnect" 'weixin-work' %}?next={{ social_next_page }}">{% trans "Disconnect" %}</a>
{% else %}
<a href="{% url "social:begin" 'weixin-work' %}?next={{ social_next_page }}">{% trans "Connect" %}</a>
{% endif %}
</li>
</ul>
</div>
{% if ENABLE_DELETE_ACCOUNT %}
<div class="setting-item" id="del-account">
<h3>{% trans "Delete Account" %}</h3>
@@ -379,5 +402,11 @@ $('#set-email-notice-interval-form').on('submit', function() {
});
return false;
});
addConfirmTo($('.social-disconnect'), {
'title':"{% trans "Disconnect" %}",
'con':"{% trans "Are you sure you want to disconnect?" %}",
'post':true
});
</script>
{% endblock %}

View File

@@ -86,6 +86,10 @@ def edit_profile(request):
email_inverval = UserOptions.objects.get_file_updates_email_interval(username)
email_inverval = email_inverval if email_inverval is not None else 0
from social_django.models import UserSocialAuth
social_connected = UserSocialAuth.objects.filter(
username=request.user.username, provider='weixin-work').count() > 0
resp_dict = {
'form': form,
'server_crypto': server_crypto,
@@ -102,6 +106,8 @@ def edit_profile(request):
'ENABLE_UPDATE_USER_INFO': ENABLE_UPDATE_USER_INFO,
'webdav_passwd': webdav_passwd,
'email_notification_interval': email_inverval,
'social_connected': social_connected,
'social_next_page': reverse('edit_profile'),
}
if has_two_factor_auth():

View File

@@ -125,6 +125,7 @@ MIDDLEWARE_CLASSES = (
'seahub.two_factor.middleware.OTPMiddleware',
'seahub.two_factor.middleware.ForceTwoFactorAuthMiddleware',
'seahub.trusted_ip.middleware.LimitIpMiddleware',
'social_django.middleware.SocialAuthExceptionMiddleware',
)
@@ -152,6 +153,9 @@ TEMPLATES = [
'django.template.context_processors.request',
'django.contrib.messages.context_processors.messages',
'social_django.context_processors.backends',
'social_django.context_processors.login_redirect',
'seahub.auth.context_processors.auth',
'seahub.base.context_processors.base',
'seahub.base.context_processors.debug',
@@ -223,6 +227,7 @@ INSTALLED_APPS = (
'post_office',
'termsandconditions',
'webpack_loader',
'social_django',
'seahub.api2',
'seahub.avatar',
@@ -264,17 +269,42 @@ CONSTANCE_BACKEND = 'constance.backends.database.DatabaseBackend'
CONSTANCE_DATABASE_CACHE_BACKEND = 'default'
AUTHENTICATION_BACKENDS = (
'seahub.social_core.backends.weixin_enterprise.WeixinWorkOAuth2',
'seahub.base.accounts.AuthBackend',
'seahub.oauth.backends.OauthRemoteUserBackend',
)
SOCIAL_AUTH_URL_NAMESPACE = 'social'
SOCIAL_AUTH_VERIFY_SSL = True
SOCIAL_AUTH_METHODS = (
('weixin', 'WeChat'),
)
SOCIAL_AUTH_WEIXIN_WORK_AGENTID = ''
SOCIAL_AUTH_WEIXIN_WORK_KEY = ''
SOCIAL_AUTH_WEIXIN_WORK_SECRET = ''
SOCIAL_AUTH_PIPELINE = (
'social_core.pipeline.social_auth.social_details',
'social_core.pipeline.social_auth.social_uid',
'social_core.pipeline.social_auth.auth_allowed',
'seahub.social_core.pipeline.social_auth.social_user',
'seahub.social_core.pipeline.user.get_username',
'seahub.social_core.pipeline.user.create_user',
'seahub.social_core.pipeline.social_auth.associate_user',
'social_core.pipeline.social_auth.load_extra_data',
# 'social_core.pipeline.user.user_details',
'seahub.social_core.pipeline.user.save_profile',
)
ENABLE_OAUTH = False
ENABLE_WATERMARK = False
# allow user to clean library trash
ENABLE_USER_CLEAN_TRASH = True
LOGIN_REDIRECT_URL = '/profile/'
LOGIN_REDIRECT_URL = '/'
LOGIN_URL = '/accounts/login/'
LOGOUT_URL = '/accounts/logout/'
LOGOUT_REDIRECT_URL = None

View File

View File

View File

@@ -0,0 +1,198 @@
import urllib
from requests import HTTPError
from django.conf import settings
from social_core.backends.oauth import BaseOAuth2
from social_core.exceptions import AuthCanceled, AuthUnknownError
import logging
logger = logging.getLogger(__name__)
try:
WEIXIN_WORK_SP = True if settings.SOCIAL_AUTH_WEIXIN_WORK_SUITID else False
except AttributeError:
WEIXIN_WORK_SP = False
if WEIXIN_WORK_SP is True:
_AUTHORIZATION_URL = 'https://open.work.weixin.qq.com/wwopen/sso/3rd_qrConnect'
_ACCESS_TOKEN_URL = 'https://qyapi.weixin.qq.com/cgi-bin/service/get_provider_token'
_USER_INFO_URL = 'https://qyapi.weixin.qq.com/cgi-bin/service/get_login_info'
else:
_AUTHORIZATION_URL = 'https://open.work.weixin.qq.com/wwopen/sso/qrConnect'
_ACCESS_TOKEN_URL = 'https://qyapi.weixin.qq.com/cgi-bin/token'
_USER_INFO_URL = 'https://qyapi.weixin.qq.com/cgi-bin/user/getuserinfo'
class WeixinWorkOAuth2(BaseOAuth2):
"""WeChat Work OAuth authentication backend"""
name = 'weixin-work'
ID_KEY = 'UserId'
AUTHORIZATION_URL = _AUTHORIZATION_URL
ACCESS_TOKEN_URL = _ACCESS_TOKEN_URL
ACCESS_TOKEN_METHOD = 'POST'
DEFAULT_SCOPE = ['snsapi_login']
REDIRECT_STATE = False
EXTRA_DATA = [
('nickname', 'username'),
('headimgurl', 'profile_image_url'),
]
def extra_data(self, user, uid, response, details=None, *args, **kwargs):
data = super(BaseOAuth2, self).extra_data(user, uid, response,
details=details,
*args, **kwargs)
if WEIXIN_WORK_SP:
data['corp_info'] = response.get('corp_info')
data['user_info'] = response.get('user_info')
return data
def get_user_id(self, details, response):
"""Return a unique ID for the current user, by default from server
response."""
if WEIXIN_WORK_SP:
return response.get('user_info').get('userid')
else:
return response.get(self.ID_KEY)
def get_user_details(self, response):
"""Return user details from Weixin. API URL is:
https://api.weixin.qq.com/sns/userinfo
"""
if WEIXIN_WORK_SP:
user_info = response.get('user_info')
return {
'userid': user_info.get('userid'),
'user_name': user_info.get('name'),
'user_avatar': user_info.get('avatar'),
'corpid': response.get('corp_info').get('corpid'),
}
else:
if self.setting('DOMAIN_AS_USERNAME'):
username = response.get('domain', '')
else:
username = response.get('nickname', '')
return {
'username': username,
'profile_image_url': response.get('headimgurl', '')
}
def user_data(self, access_token, *args, **kwargs):
if WEIXIN_WORK_SP:
data = self.get_json(_USER_INFO_URL,
params={'access_token': access_token},
json={'auth_code': kwargs['request'].GET.get('auth_code')},
headers={'Content-Type': 'application/json',
'Accept': 'application/json'},
method='post')
else:
data = self.get_json(_USER_INFO_URL, params={
'access_token': access_token,
'code': kwargs['request'].GET.get('code')
})
nickname = data.get('nickname')
if nickname:
# weixin api has some encode bug, here need handle
data['nickname'] = nickname.encode(
'raw_unicode_escape'
).decode('utf-8')
return data
def auth_params(self, state=None):
appid, secret = self.get_key_and_secret()
if WEIXIN_WORK_SP:
params = {
'appid': appid,
'redirect_uri': self.get_redirect_uri(state),
'usertype': 'member',
}
else:
params = {
'appid': appid,
'redirect_uri': self.get_redirect_uri(state),
'agentid': self.setting('AGENTID'),
}
if self.STATE_PARAMETER and state:
params['state'] = state
if self.RESPONSE_TYPE:
params['response_type'] = self.RESPONSE_TYPE
return params
def auth_complete_params(self, state=None):
appid, secret = self.get_key_and_secret()
if WEIXIN_WORK_SP is True:
return {
'corpid': appid,
'provider_secret': secret,
}
return {
'grant_type': 'authorization_code', # request auth code
'code': self.data.get('code', ''), # server response code
'appid': appid,
'secret': secret,
'redirect_uri': self.get_redirect_uri(state),
}
def refresh_token_params(self, token, *args, **kwargs):
appid, secret = self.get_key_and_secret()
return {
'refresh_token': token,
'grant_type': 'refresh_token',
'appid': appid,
'secret': secret
}
def access_token_url(self, appid, secret):
if WEIXIN_WORK_SP:
return self.ACCESS_TOKEN_URL
else:
return self.ACCESS_TOKEN_URL + '?corpid=%s&corpsecret=%s' % (appid, secret)
def auth_complete(self, *args, **kwargs):
"""Completes loging process, must return user instance"""
self.process_error(self.data)
appid, secret = self.get_key_and_secret()
try:
if WEIXIN_WORK_SP:
response = self.request_access_token(
self.access_token_url(appid, secret),
json=self.auth_complete_params(self.validate_state()),
headers={'Content-Type': 'application/json',
'Accept': 'application/json'},
method=self.ACCESS_TOKEN_METHOD
)
else:
response = self.request_access_token(
self.access_token_url(appid, secret),
data=self.auth_complete_params(self.validate_state()),
headers=self.auth_headers(),
method=self.ACCESS_TOKEN_METHOD
)
except HTTPError as err:
if err.response.status_code == 400:
raise AuthCanceled(self, response=err.response)
else:
raise
except KeyError:
raise AuthUnknownError(self)
try:
if response['errmsg'] != 'ok':
raise AuthCanceled(self)
except KeyError:
pass # assume response is ok if 'errmsg' key not found
self.process_error(response)
access_token = response['provider_access_token'] if WEIXIN_WORK_SP else response['access_token']
return self.do_auth(access_token, response=response,
*args, **kwargs)

View File

View File

@@ -0,0 +1,34 @@
from social_core.exceptions import AuthAlreadyAssociated
def social_user(backend, uid, user=None, *args, **kwargs):
provider = backend.name
social = backend.strategy.storage.user.get_social_auth(provider, uid)
if social:
if user and social.user.username != user.username:
msg = 'This {0} account is already in use.'.format(provider)
raise AuthAlreadyAssociated(backend, msg)
elif not user:
user = social.user
return {'social': social,
'user': user,
'is_new': user is None,
'new_association': social is None}
def associate_user(backend, uid, user=None, social=None, *args, **kwargs):
if user and not social:
try:
social = backend.strategy.storage.user.create_social_auth(
user, uid, backend.name
)
except Exception as err:
if not backend.strategy.storage.is_integrity_error(err):
raise
# Protect for possible race condition, those bastard with FTL
# clicking capabilities, check issue #131:
# https://github.com/omab/django-social-auth/issues/131
return social_user(backend, uid, user, *args, **kwargs)
else:
return {'social': social,
'user': user,
'new_association': True}

View File

@@ -0,0 +1,87 @@
from seahub.profile.models import Profile
from seahub.utils.auth import gen_user_virtual_id
USER_FIELDS = ['username', 'email']
def get_username(strategy, details, backend, user=None, *args, **kwargs):
if 'username' not in backend.setting('USER_FIELDS', USER_FIELDS):
return
storage = strategy.storage
if not user:
final_username = gen_user_virtual_id()
else:
final_username = storage.user.get_username(user)
return {'username': final_username}
def create_user(strategy, details, backend, user=None, *args, **kwargs):
if user:
return {'is_new': False}
fields = dict((name, kwargs.get(name, details.get(name)))
for name in backend.setting('USER_FIELDS', USER_FIELDS))
if not fields:
return
return {
'is_new': True,
'user': strategy.create_user(**fields)
}
def save_profile(strategy, details, backend, user=None, *args, **kwargs):
if not user:
return
email = details.get('email', '')
if email:
Profile.objects.add_or_update(username=user.username,
contact_email=email)
fullname = details.get('fullname', '')
if fullname:
Profile.objects.add_or_update(username=user.username,
nickname=fullname)
# weixin username and profile_image_url
nickname = details.get('username', '')
if nickname:
Profile.objects.add_or_update(username=user.username,
nickname=nickname)
avatar_url = details.get('profile_image_url', '')
if avatar_url:
_update_user_avatar(user, avatar_url)
import os
import logging
import urllib2
from django.core.files import File
from seahub.avatar.models import Avatar
from seahub.avatar.signals import avatar_updated
logger = logging.getLogger(__name__)
def _update_user_avatar(user, pic):
if not pic:
return
logger.info("retrieve pic from %s" % pic)
filedata = urllib2.urlopen(pic)
datatowrite = filedata.read()
filename = '/tmp/%s.jpg' % user.username
with open(filename, 'wb') as f:
f.write(datatowrite)
logger.info("save pic to %s" % filename)
avatar = Avatar(emailuser=user.username, primary=True)
avatar.avatar.save(
'image.jpg', File(open(filename))
)
avatar.save()
avatar_updated.send(sender=Avatar, user=user, avatar=avatar)
os.remove(filename)

View File

@@ -0,0 +1,275 @@
#!/usr/bin/env python
#-*- encoding:utf-8 -*-
""" 对企业微信发送给企业后台的消息加解密示例代码.
@copyright: Copyright (c) 1998-2014 Tencent Inc.
"""
# ------------------------------------------------------------------------
import base64
import string
import random
import hashlib
import time
import struct
try:
from Crypto.Cipher import AES
except ImportError:
AES = None
import xml.etree.cElementTree as ET
import sys
import socket
reload(sys)
from . import ierror
sys.setdefaultencoding('utf-8')
"""
关于Crypto.Cipher模块ImportError: No module named 'Crypto'解决方案
请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。
下载后按照README中的“Installation”小节的提示进行pycrypto安装。
"""
class FormatException(Exception):
pass
def throw_exception(message, exception_class=FormatException):
"""my define raise exception function"""
raise exception_class(message)
class SHA1:
"""计算企业微信的消息签名接口"""
def getSHA1(self, token, timestamp, nonce, encrypt):
"""用SHA1算法生成安全签名
@param token: 票据
@param timestamp: 时间戳
@param encrypt: 密文
@param nonce: 随机字符串
@return: 安全签名
"""
try:
sortlist = [token, timestamp, nonce, encrypt]
sortlist.sort()
sha = hashlib.sha1()
sha.update("".join(sortlist))
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
except Exception,e:
print e
return ierror.WXBizMsgCrypt_ComputeSignature_Error, None
class XMLParse:
"""提供提取消息格式中的密文及生成回复消息格式的接口"""
# xml消息模板
AES_TEXT_RESPONSE_TEMPLATE = """<xml>
<Encrypt><![CDATA[%(msg_encrypt)s]]></Encrypt>
<MsgSignature><![CDATA[%(msg_signaturet)s]]></MsgSignature>
<TimeStamp>%(timestamp)s</TimeStamp>
<Nonce><![CDATA[%(nonce)s]]></Nonce>
</xml>"""
def extract(self, xmltext):
"""提取出xml数据包中的加密消息
@param xmltext: 待提取的xml字符串
@return: 提取出的加密消息字符串
"""
try:
xml_tree = ET.fromstring(xmltext)
encrypt = xml_tree.find("Encrypt")
return ierror.WXBizMsgCrypt_OK, encrypt.text
except Exception,e:
print e
return ierror.WXBizMsgCrypt_ParseXml_Error,None,None
def generate(self, encrypt, signature, timestamp, nonce):
"""生成xml消息
@param encrypt: 加密后的消息密文
@param signature: 安全签名
@param timestamp: 时间戳
@param nonce: 随机字符串
@return: 生成的xml字符串
"""
resp_dict = {
'msg_encrypt' : encrypt,
'msg_signaturet': signature,
'timestamp' : timestamp,
'nonce' : nonce,
}
resp_xml = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
return resp_xml
class PKCS7Encoder():
"""提供基于PKCS7算法的加解密接口"""
block_size = 32
def encode(self, text):
""" 对需要加密的明文进行填充补位
@param text: 需要进行填充补位操作的明文
@return: 补齐明文字符串
"""
text_length = len(text)
# 计算需要填充的位数
amount_to_pad = self.block_size - (text_length % self.block_size)
if amount_to_pad == 0:
amount_to_pad = self.block_size
# 获得补位所用的字符
pad = chr(amount_to_pad)
return text + pad * amount_to_pad
def decode(self, decrypted):
"""删除解密后明文的补位字符
@param decrypted: 解密后的明文
@return: 删除补位字符后的明文
"""
pad = ord(decrypted[-1])
if pad<1 or pad >32:
pad = 0
return decrypted[:-pad]
class Prpcrypt(object):
"""提供接收和推送给企业微信消息的加解密接口"""
def __init__(self,key):
#self.key = base64.b64decode(key+"=")
self.key = key
# 设置加解密模式为AES的CBC模式
self.mode = AES.MODE_CBC
def encrypt(self,text,receiveid):
"""对明文进行加密
@param text: 需要加密的明文
@return: 加密得到的字符串
"""
# 16位随机字符串添加到明文开头
text = self.get_random_str() + struct.pack("I",socket.htonl(len(text))) + text + receiveid
# 使用自定义的填充方式对明文进行补位填充
pkcs7 = PKCS7Encoder()
text = pkcs7.encode(text)
# 加密
cryptor = AES.new(self.key,self.mode,self.key[:16])
try:
ciphertext = cryptor.encrypt(text)
# 使用BASE64对加密后的字符串进行编码
return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
except Exception,e:
print e
return ierror.WXBizMsgCrypt_EncryptAES_Error,None
def decrypt(self,text,receiveid):
"""对解密后的明文进行补位删除
@param text: 密文
@return: 删除填充补位后的明文
"""
try:
cryptor = AES.new(self.key,self.mode,self.key[:16])
# 使用BASE64对密文进行解码然后AES-CBC解密
plain_text = cryptor.decrypt(base64.b64decode(text))
except Exception,e:
print e
return ierror.WXBizMsgCrypt_DecryptAES_Error,None
try:
pad = ord(plain_text[-1])
# 去掉补位字符串
#pkcs7 = PKCS7Encoder()
#plain_text = pkcs7.encode(plain_text)
# 去除16位随机字符串
content = plain_text[16:-pad]
xml_len = socket.ntohl(struct.unpack("I",content[ : 4])[0])
xml_content = content[4 : xml_len+4]
from_receiveid = content[xml_len+4:]
except Exception,e:
print e
return ierror.WXBizMsgCrypt_IllegalBuffer,None
if from_receiveid != receiveid:
return ierror.WXBizMsgCrypt_ValidateCorpid_Error,None
return 0,xml_content
def get_random_str(self):
""" 随机生成16位字符串
@return: 16位字符串
"""
rule = string.letters + string.digits
str = random.sample(rule, 16)
return "".join(str)
class WXBizMsgCrypt(object):
#构造函数
def __init__(self,sToken,sEncodingAESKey,sReceiveId):
try:
self.key = base64.b64decode(sEncodingAESKey+"=")
assert len(self.key) == 32
except:
throw_exception("[error]: EncodingAESKey unvalid !", FormatException)
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
self.m_sToken = sToken
self.m_sReceiveId = sReceiveId
#验证URL
#@param sMsgSignature: 签名串对应URL参数的msg_signature
#@param sTimeStamp: 时间戳对应URL参数的timestamp
#@param sNonce: 随机串对应URL参数的nonce
#@param sEchoStr: 随机串对应URL参数的echostr
#@param sReplyEchoStr: 解密之后的echostr当return返回0时有效
#@return成功0失败返回对应的错误码
def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
sha1 = SHA1()
ret,signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret,sReplyEchoStr = pc.decrypt(sEchoStr,self.m_sReceiveId)
return ret,sReplyEchoStr
def EncryptMsg(self, sReplyMsg, sNonce, timestamp = None):
#将企业回复用户的消息加密打包
#@param sReplyMsg: 企业号待回复用户的消息xml格式的字符串
#@param sTimeStamp: 时间戳可以自己生成也可以用URL参数的timestamp,如为None则自动用当前时间
#@param sNonce: 随机串可以自己生成也可以用URL参数的nonce
#sEncryptMsg: 加密后的可以直接回复用户的密文包括msg_signature, timestamp, nonce, encrypt的xml格式的字符串,
#return成功0sEncryptMsg,失败返回对应的错误码None
pc = Prpcrypt(self.key)
ret,encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
if ret != 0:
return ret,None
if timestamp is None:
timestamp = str(int(time.time()))
# 生成安全签名
sha1 = SHA1()
ret,signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
if ret != 0:
return ret,None
xmlParse = XMLParse()
return ret,xmlParse.generate(encrypt, signature, timestamp, sNonce)
def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
# 检验消息的真实性,并且获取解密后的明文
# @param sMsgSignature: 签名串对应URL参数的msg_signature
# @param sTimeStamp: 时间戳对应URL参数的timestamp
# @param sNonce: 随机串对应URL参数的nonce
# @param sPostData: 密文对应POST请求的数据
# xml_content: 解密后的原文当return返回0时有效
# @return: 成功0失败返回对应的错误码
# 验证安全签名
xmlParse = XMLParse()
ret,encrypt = xmlParse.extract(sPostData)
if ret != 0:
return ret, None
sha1 = SHA1()
ret,signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret,xml_content = pc.decrypt(encrypt,self.m_sReceiveId)
return ret,xml_content

View File

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#########################################################################
# Author: jonyqin
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
# File Name: ierror.py
# Description:定义错误码含义
#########################################################################
WXBizMsgCrypt_OK = 0
WXBizMsgCrypt_ValidateSignature_Error = -40001
WXBizMsgCrypt_ParseXml_Error = -40002
WXBizMsgCrypt_ComputeSignature_Error = -40003
WXBizMsgCrypt_IllegalAesKey = -40004
WXBizMsgCrypt_ValidateCorpid_Error = -40005
WXBizMsgCrypt_EncryptAES_Error = -40006
WXBizMsgCrypt_DecryptAES_Error = -40007
WXBizMsgCrypt_IllegalBuffer = -40008
WXBizMsgCrypt_EncodeBase64_Error = -40009
WXBizMsgCrypt_DecodeBase64_Error = -40010
WXBizMsgCrypt_GenReturnXml_Error = -40011

145
seahub/social_core/views.py Normal file
View File

@@ -0,0 +1,145 @@
import logging
import requests
from django.conf import settings
from django.core.cache import cache
from django.core.urlresolvers import reverse
from django.http import HttpResponse, HttpResponseRedirect
from django.views.decorators.csrf import csrf_exempt
from django.utils.http import urlquote
from seahub.social_core.utils.WXBizMsgCrypt import WXBizMsgCrypt
from seahub.utils.urls import abs_reverse
# Get an instance of a logger
logger = logging.getLogger(__name__)
@csrf_exempt
def weixin_work_cb(request):
"""Callback for weixin work provider API.
Used in callback config at app details page.
e.g. https://open.work.weixin.qq.com/wwopen/developer#/sass/apps/detail/ww24c53566499d354f
ref: https://work.weixin.qq.com/api/doc#90001/90143/91116
"""
token = settings.SOCIAL_AUTH_WEIXIN_WORK_TOKEN
EncodingAESKey = settings.SOCIAL_AUTH_WEIXIN_WORK_AES_KEY
msg_signature = request.GET.get('msg_signature', None)
timestamp = request.GET.get('timestamp', None)
nonce = request.GET.get('nonce', None)
if not (msg_signature and timestamp and nonce):
assert False, 'Request Error'
if request.method == 'GET':
wxcpt = WXBizMsgCrypt(token, EncodingAESKey,
settings.SOCIAL_AUTH_WEIXIN_WORK_KEY)
echostr = request.GET.get('echostr', '')
ret, decoded_echostr = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
assert False, 'Verify Error'
return HttpResponse(decoded_echostr)
elif request.method == 'POST':
wxcpt = WXBizMsgCrypt(token, EncodingAESKey,
settings.SOCIAL_AUTH_WEIXIN_WORK_SUITID)
ret, xml_msg = wxcpt.DecryptMsg(request.body, msg_signature, timestamp, nonce)
if ret != 0:
assert False, 'Decrypt Error'
import xml.etree.cElementTree as ET
xml_tree = ET.fromstring(xml_msg)
suite_ticket = xml_tree.find("SuiteTicket").text
logger.info('suite ticket: %s' % suite_ticket)
# TODO: use persistent store
cache.set('wx_work_suite_ticket', suite_ticket, 3600)
return HttpResponse('success')
def _get_suite_access_token():
suite_access_token = cache.get('wx_work_suite_access_token', None)
if suite_access_token:
return suite_access_token
suite_ticket = cache.get('wx_work_suite_ticket', None)
if not suite_ticket:
assert False, 'suite ticket is None!'
get_suite_token_url = 'https://qyapi.weixin.qq.com/cgi-bin/service/get_suite_token'
resp = requests.request(
'POST', get_suite_token_url,
json={
"suite_id": settings.SOCIAL_AUTH_WEIXIN_WORK_SUITID,
"suite_secret": settings.SOCIAL_AUTH_WEIXIN_WORK_SUIT_SECRET,
"suite_ticket": suite_ticket,
},
headers={'Content-Type': 'application/json',
'Accept': 'application/json'},
)
suite_access_token = resp.json().get('suite_access_token', None)
if not suite_access_token:
logger.error('Failed to get suite_access_token!')
logger.error(resp.content)
assert False, 'suite_access_token is None!'
else:
cache.set('wx_work_suite_access_token', suite_access_token, 3600)
return suite_access_token
def weixin_work_3rd_app_install(request):
"""Redirect user to weixin work 3rd app install page.
"""
# 0. get suite access token
suite_access_token = _get_suite_access_token()
print('suite access token', suite_access_token)
# 1. get pre_auth_code
get_pre_auth_code_url = 'https://qyapi.weixin.qq.com/cgi-bin/service/get_pre_auth_code?suite_access_token=' + suite_access_token
resp = requests.request('GET', get_pre_auth_code_url)
pre_auth_code = resp.json().get('pre_auth_code', None)
if not pre_auth_code:
logger.error('Failed to get pre_auth_code')
logger.error(resp.content)
assert False, 'pre_auth_code is None'
# 2. set session info
# ref: https://work.weixin.qq.com/api/doc#90001/90143/90602
url = 'https://qyapi.weixin.qq.com/cgi-bin/service/set_session_info?suite_access_token=' + suite_access_token
resp = requests.request(
'POST', url,
json={
"pre_auth_code": pre_auth_code,
"session_info":
{
"appid": [],
"auth_type": 1 # TODO: 0: production; 1: testing.
}
},
headers={'Content-Type': 'application/json',
'Accept': 'application/json'},
)
# TODO: use random state
url = 'https://open.work.weixin.qq.com/3rdapp/install?suite_id=%s&pre_auth_code=%s&redirect_uri=%s&state=STATE123' % (
settings.SOCIAL_AUTH_WEIXIN_WORK_SUITID,
pre_auth_code,
abs_reverse('weixin_work_3rd_app_install_cb'),
)
return HttpResponseRedirect(url)
@csrf_exempt
def weixin_work_3rd_app_install_cb(request):
"""Callback for weixin work 3rd app install API.
https://work.weixin.qq.com/api/doc#90001/90143/90597
"""
# TODO: check state
pass

View File

@@ -137,6 +137,7 @@ urlpatterns = [
url(r'^sso/$', sso, name='sso'),
url(r'^shib-login/', shib_login, name="shib_login"),
url(r'^oauth/', include('seahub.oauth.urls')),
url(r'^social/', include('social_django.urls', namespace='social')),
url(r'^$', libraries, name='libraries'),
#url(r'^home/$', direct_to_template, { 'template': 'home.html' } ),
@@ -693,3 +694,15 @@ if getattr(settings, 'ENABLE_CAS', False):
url(r'^accounts/cas-logout/$', cas_logout, name='cas_ng_logout'),
url(r'^accounts/cas-callback/$', cas_callback, name='cas_ng_proxy_callback'),
]
from seahub.social_core.views import (
weixin_work_cb, weixin_work_3rd_app_install, weixin_work_3rd_app_install_cb
)
urlpatterns += [
url(r'^weixin-work/callback/$', weixin_work_cb),
url(r'^weixin-work/3rd-app-install/$', weixin_work_3rd_app_install),
url(r'^weixin-work/3rd-app-install/callback/$',
weixin_work_3rd_app_install_cb, name='weixin_work_3rd_app_install_cb'),
]

View File

@@ -1,5 +1,6 @@
import os
from seahub.settings import LOGIN_BG_IMAGE_PATH, MEDIA_ROOT
from seahub.utils import gen_token
def get_login_bg_image_path():
""" Return custom background image path if it exists, otherwise return default background image path.
@@ -15,3 +16,6 @@ def get_custom_login_bg_image_path():
""" Ensure consistency between utils and api.
"""
return 'custom/login-bg.jpg'
def gen_user_virtual_id():
return gen_token(max_length=32) + '@auth.local'

7
seahub/utils/urls.py Normal file
View File

@@ -0,0 +1,7 @@
from django.core.urlresolvers import reverse
from seahub.utils import get_site_scheme_and_netloc
def abs_reverse(viewname, urlconf=None, args=None, kwargs=None, current_app=None):
return get_site_scheme_and_netloc().rstrip('/') + reverse(
viewname, urlconf, args, kwargs, current_app)

View File

View File

@@ -0,0 +1,22 @@
import os
import pytest
from seahub.test_utils import BaseTestCase
TRAVIS = 'TRAVIS' in os.environ
class WeixinWorkCBTest(BaseTestCase):
@pytest.mark.skipif(TRAVIS, reason="This test can only be run in local.")
def test_get(self, ):
resp = self.client.get('/weixin-work/callback/?msg_signature=61a7d120857cdb70d8b936ec5b6e8ed172a41926&timestamp=1543304575&nonce=1542460575&echostr=9uB%2FReg5PQk%2FjzejPjhjWmvKXuxh0R4VK7BJRP62lfRj5kZhuAu0mLMM7hnREJQTJxWWw3Y1BB%2F%2FLkE3V88auA%3D%3D')
assert resp.content == '6819653789729882111'
@pytest.mark.skipif(TRAVIS, reason="This test can only be run in local.")
def test_post(self, ):
data = '<xml><ToUserName><![CDATA[ww24c53566499d354f]]></ToUserName><Encrypt><![CDATA[1fBBPRF7NW4ocCIWFIZK/Pjcn5a0okyx3O8OdbX6Ci2MYq34NaIWuK9jW6dq8pVORvUUsxNP0RVD3vqpq94P932bMyBNKHvFgdn62NaM3vUCSN2SJhwlvNp1KDqMDCX+oiMjcSWJFWXJ0daTpxycSJ88LKH1tA/Z3n18yGq7qs/7qmFJp2kaL6/sb9ATWriA/BCH5UhOaJolqLNm281yAbap+1myr2ELCHPqWz0Gd6Zpvolab6caAp+ivAK5+LohgkrppAjkW7CXI1yM08X0VNArmIT55ZKTFwSW6jeMTBUIIVdYimAKxfxmITxtcu7dVGFQ63hyJTtH6MI0yc7wZRL2ZX9OR5cbO5WTksXv0Rai/3lGSPjThOUS02EI8j4h]]></Encrypt><AgentID><![CDATA[]]></AgentID></xml>'
resp = self.client.post(
'/weixin-work/callback/?msg_signature=a237bf482cc9ae8424010eb63a24859c731b2aa7&timestamp=1543309590&nonce=1542845878',
data=data,
content_type='application/xml',
)

View File

@@ -0,0 +1,23 @@
__version__ = '2.1.0'
from social_core.backends.base import BaseAuth
# django.contrib.auth.load_backend() will import and instanciate the
# authentication backend ignoring the possibility that it might
# require more arguments. Here we set a monkey patch to
# BaseAuth.__init__ to ignore the mandatory strategy argument and load
# it.
def baseauth_init_workaround(original_init):
def fake_init(self, strategy=None, *args, **kwargs):
from .utils import load_strategy
original_init(self, strategy or load_strategy(), *args, **kwargs)
return fake_init
if not getattr(BaseAuth, '__init_patched', False):
BaseAuth.__init__ = baseauth_init_workaround(BaseAuth.__init__)
BaseAuth.__init_patched = True
default_app_config = 'social_django.config.PythonSocialAuthConfig'

View File

@@ -0,0 +1,62 @@
"""Admin settings"""
from itertools import chain
from django.conf import settings
from django.contrib import admin
from social_core.utils import setting_name
from .models import UserSocialAuth, Nonce, Association
class UserSocialAuthOption(admin.ModelAdmin):
"""Social Auth user options"""
list_display = ('user', 'id', 'provider', 'uid')
list_filter = ('provider',)
raw_id_fields = ('user',)
list_select_related = True
def get_search_fields(self, request=None):
search_fields = getattr(
settings, setting_name('ADMIN_USER_SEARCH_FIELDS'), None
)
if search_fields is None:
_User = UserSocialAuth.user_model()
username = getattr(_User, 'USERNAME_FIELD', None) or \
hasattr(_User, 'username') and 'username' or \
None
fieldnames = ('first_name', 'last_name', 'email', username)
all_names = self._get_all_field_names(_User._meta)
search_fields = [name for name in fieldnames
if name and name in all_names]
return ['user__' + name for name in search_fields] + \
getattr(settings, setting_name('ADMIN_SEARCH_FIELDS'), [])
@staticmethod
def _get_all_field_names(model):
names = chain.from_iterable(
(field.name, field.attname)
if hasattr(field, 'attname') else (field.name,)
for field in model.get_fields()
# For complete backwards compatibility, you may want to exclude
# GenericForeignKey from the results.
if not (field.many_to_one and field.related_model is None)
)
return list(set(names))
class NonceOption(admin.ModelAdmin):
"""Nonce options"""
list_display = ('id', 'server_url', 'timestamp', 'salt')
search_fields = ('server_url',)
class AssociationOption(admin.ModelAdmin):
"""Association options"""
list_display = ('id', 'server_url', 'assoc_type')
list_filter = ('assoc_type',)
search_fields = ('server_url',)
admin.site.register(UserSocialAuth, UserSocialAuthOption)
admin.site.register(Nonce, NonceOption)
admin.site.register(Association, AssociationOption)

View File

@@ -0,0 +1,34 @@
# coding=utf-8
import six
import django
from django.db import models
try:
from django.urls import reverse
except ImportError:
from django.core.urlresolvers import reverse
try:
from django.utils.deprecation import MiddlewareMixin
except ImportError:
MiddlewareMixin = object
def get_rel_model(field):
if django.VERSION >= (2, 0):
return field.remote_field.model
user_model = field.rel.to
if isinstance(user_model, six.string_types):
app_label, model_name = user_model.split('.')
user_model = models.get_model(app_label, model_name)
return user_model
def get_request_port(request):
if django.VERSION >= (1, 9):
return request.get_port()
host_parts = request.get_host().partition(':')
return host_parts[2] or request.META['SERVER_PORT']

View File

@@ -0,0 +1,10 @@
from django.apps import AppConfig
class PythonSocialAuthConfig(AppConfig):
# Full Python path to the application eg. 'django.contrib.admin'.
name = 'social_django'
# Last component of the Python path to the application eg. 'admin'.
label = 'social_django'
# Human-readable name for the application eg. "Admin".
verbose_name = 'Python Social Auth'

View File

@@ -0,0 +1,52 @@
from django.contrib.auth import REDIRECT_FIELD_NAME
from django.utils.functional import SimpleLazyObject
from django.utils.http import urlquote
try:
from django.utils.functional import empty as _empty
empty = _empty
except ImportError: # django < 1.4
empty = None
from social_core.backends.utils import user_backends_data
from .utils import Storage, BACKENDS
class LazyDict(SimpleLazyObject):
"""Lazy dict initialization."""
def __getitem__(self, name):
if self._wrapped is empty:
self._setup()
return self._wrapped[name]
def __setitem__(self, name, value):
if self._wrapped is empty:
self._setup()
self._wrapped[name] = value
def backends(request):
"""Load Social Auth current user data to context under the key 'backends'.
Will return the output of social_core.backends.utils.user_backends_data."""
return {'backends': LazyDict(lambda: user_backends_data(request.user,
BACKENDS,
Storage))}
def login_redirect(request):
"""Load current redirect to context."""
value = request.method == 'POST' and \
request.POST.get(REDIRECT_FIELD_NAME) or \
request.GET.get(REDIRECT_FIELD_NAME)
if value:
value = urlquote(value)
querystring = REDIRECT_FIELD_NAME + '=' + value
else:
querystring = ''
return {
'REDIRECT_FIELD_NAME': REDIRECT_FIELD_NAME,
'REDIRECT_FIELD_VALUE': value,
'REDIRECT_QUERYSTRING': querystring
}

View File

@@ -0,0 +1,94 @@
import json
import six
import functools
import django
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db import models
from social_core.utils import setting_name
try:
from django.utils.encoding import smart_unicode as smart_text
smart_text # placate pyflakes
except ImportError:
from django.utils.encoding import smart_text
# SubfieldBase causes RemovedInDjango110Warning in 1.8 and 1.9, and
# will not work in 1.10 or later
if django.VERSION[:2] >= (1, 8):
field_metaclass = type
else:
from django.db.models import SubfieldBase
field_metaclass = SubfieldBase
field_class = functools.partial(six.with_metaclass, field_metaclass)
if getattr(settings, setting_name('POSTGRES_JSONFIELD'), False):
from django.contrib.postgres.fields import JSONField as JSONFieldBase
else:
JSONFieldBase = field_class(models.TextField)
class JSONField(JSONFieldBase):
"""Simple JSON field that stores python structures as JSON strings
on database.
"""
def __init__(self, *args, **kwargs):
kwargs.setdefault('default', dict)
super(JSONField, self).__init__(*args, **kwargs)
def from_db_value(self, value, expression, connection, context):
return self.to_python(value)
def to_python(self, value):
"""
Convert the input JSON value into python structures, raises
django.core.exceptions.ValidationError if the data can't be converted.
"""
if self.blank and not value:
return {}
value = value or '{}'
if isinstance(value, six.binary_type):
value = six.text_type(value, 'utf-8')
if isinstance(value, six.string_types):
try:
# with django 1.6 i have '"{}"' as default value here
if value[0] == value[-1] == '"':
value = value[1:-1]
return json.loads(value)
except Exception as err:
raise ValidationError(str(err))
else:
return value
def validate(self, value, model_instance):
"""Check value is a valid JSON string, raise ValidationError on
error."""
if isinstance(value, six.string_types):
super(JSONField, self).validate(value, model_instance)
try:
json.loads(value)
except Exception as err:
raise ValidationError(str(err))
def get_prep_value(self, value):
"""Convert value to JSON string before save"""
try:
return json.dumps(value)
except Exception as err:
raise ValidationError(str(err))
def value_to_string(self, obj):
"""Return value from object converted to string properly"""
return smart_text(self.value_from_object(obj))
def value_from_object(self, obj):
"""Return value dumped to string."""
orig_val = super(JSONField, self).value_from_object(obj)
return self.get_prep_value(orig_val)

View File

@@ -0,0 +1,35 @@
from datetime import timedelta
from django.core.management.base import BaseCommand
from django.utils import timezone
from social_django.models import Code, Partial
class Command(BaseCommand):
help = 'removes old not used verification codes and partials'
def add_arguments(self, parser):
super(Command, self).add_arguments(parser)
parser.add_argument(
'--age',
action='store',
type=int,
dest='age',
default=14,
help='how long to keep unused data (in days, defaults to 14)'
)
def handle(self, *args, **options):
age = timezone.now() - timedelta(days=options['age'])
# Delete old not verified codes
Code.objects.filter(
verified=False,
timestamp__lt=age
).delete()
# Delete old partial data
Partial.objects.filter(
timestamp__lt=age
).delete()

View File

@@ -0,0 +1,15 @@
from django.db import models
class UserSocialAuthManager(models.Manager):
"""Manager for the UserSocialAuth django model."""
class Meta:
app_label = "social_django"
def get_social_auth(self, provider, uid):
try:
return self.select_related('user').get(provider=provider,
uid=uid)
except self.model.DoesNotExist:
return None

View File

@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
import six
from django.apps import apps
from django.conf import settings
from django.contrib import messages
from django.contrib.messages.api import MessageFailure
from django.core.urlresolvers import reverse
from django.shortcuts import redirect
from django.utils.http import urlquote
from social_core.exceptions import SocialAuthBaseException
from social_core.utils import social_logger
from .compat import MiddlewareMixin
class SocialAuthExceptionMiddleware(MiddlewareMixin):
"""Middleware that handles Social Auth AuthExceptions by providing the user
with a message, logging an error, and redirecting to some next location.
By default, the exception message itself is sent to the user and they are
redirected to the location specified in the SOCIAL_AUTH_LOGIN_ERROR_URL
setting.
This middleware can be extended by overriding the get_message or
get_redirect_uri methods, which each accept request and exception.
"""
def process_exception(self, request, exception):
strategy = getattr(request, 'social_strategy', None)
if strategy is None or self.raise_exception(request, exception):
return
if isinstance(exception, SocialAuthBaseException):
backend = getattr(request, 'backend', None)
backend_name = getattr(backend, 'name', 'unknown-backend')
message = self.get_message(request, exception)
url = self.get_redirect_uri(request, exception)
if apps.is_installed('django.contrib.messages'):
social_logger.info(message)
try:
messages.error(request, message,
extra_tags='social-auth ' + backend_name)
except MessageFailure:
if url:
url += ('?' in url and '&' or '?') + \
'message={0}&backend={1}'.format(urlquote(message),
backend_name)
else:
social_logger.error(message)
if url:
return redirect(url)
else:
return redirect(reverse('edit_profile'))
def raise_exception(self, request, exception):
strategy = getattr(request, 'social_strategy', None)
if strategy is not None:
return strategy.setting('RAISE_EXCEPTIONS') or settings.DEBUG
def get_message(self, request, exception):
return six.text_type(exception)
def get_redirect_uri(self, request, exception):
strategy = getattr(request, 'social_strategy', None)
return strategy.setting('LOGIN_ERROR_URL')

View File

@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.11 on 2018-05-25 03:27
from __future__ import unicode_literals
from django.db import migrations, models
import seahub.base.fields
import social_django.fields
import social_django.storage
class Migration(migrations.Migration):
initial = True
dependencies = [
]
operations = [
migrations.CreateModel(
name='Association',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('server_url', models.CharField(max_length=255)),
('handle', models.CharField(max_length=255)),
('secret', models.CharField(max_length=255)),
('issued', models.IntegerField()),
('lifetime', models.IntegerField()),
('assoc_type', models.CharField(max_length=64)),
],
options={
'db_table': 'social_auth_association',
},
bases=(models.Model, social_django.storage.DjangoAssociationMixin),
),
migrations.CreateModel(
name='Code',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('email', models.EmailField(max_length=254)),
('code', models.CharField(db_index=True, max_length=32)),
('verified', models.BooleanField(default=False)),
('timestamp', models.DateTimeField(auto_now_add=True, db_index=True)),
],
options={
'db_table': 'social_auth_code',
},
bases=(models.Model, social_django.storage.DjangoCodeMixin),
),
migrations.CreateModel(
name='Nonce',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('server_url', models.CharField(max_length=255)),
('timestamp', models.IntegerField()),
('salt', models.CharField(max_length=65)),
],
options={
'db_table': 'social_auth_nonce',
},
bases=(models.Model, social_django.storage.DjangoNonceMixin),
),
migrations.CreateModel(
name='Partial',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('token', models.CharField(db_index=True, max_length=32)),
('next_step', models.PositiveSmallIntegerField(default=0)),
('backend', models.CharField(max_length=32)),
('data', social_django.fields.JSONField(default=dict)),
('timestamp', models.DateTimeField(auto_now_add=True, db_index=True)),
],
options={
'db_table': 'social_auth_partial',
},
bases=(models.Model, social_django.storage.DjangoPartialMixin),
),
migrations.CreateModel(
name='UserSocialAuth',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('username', seahub.base.fields.LowerCaseCharField(db_index=True, max_length=255)),
('provider', models.CharField(max_length=32)),
('uid', models.CharField(max_length=255)),
('extra_data', social_django.fields.JSONField(default=dict)),
],
options={
'db_table': 'social_auth_usersocialauth',
},
bases=(models.Model, social_django.storage.DjangoUserMixin),
),
migrations.AlterUniqueTogether(
name='usersocialauth',
unique_together=set([('provider', 'uid')]),
),
migrations.AlterUniqueTogether(
name='nonce',
unique_together=set([('server_url', 'timestamp', 'salt')]),
),
migrations.AlterUniqueTogether(
name='code',
unique_together=set([('email', 'code')]),
),
migrations.AlterUniqueTogether(
name='association',
unique_together=set([('server_url', 'handle')]),
),
]

View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.15 on 2018-11-15 08:25
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('social_django', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='usersocialauth',
name='uid',
field=models.CharField(max_length=150),
),
]

View File

@@ -0,0 +1,150 @@
"""Django ORM models for Social Auth"""
import six
from django.db import models
from django.conf import settings
from django.db.utils import IntegrityError
from social_core.utils import setting_name
from seahub.base.accounts import User
from seahub.base.fields import LowerCaseCharField
from .compat import get_rel_model
from .storage import DjangoUserMixin, DjangoAssociationMixin, \
DjangoNonceMixin, DjangoCodeMixin, \
DjangoPartialMixin, BaseDjangoStorage
from .fields import JSONField
from .managers import UserSocialAuthManager
USER_MODEL = getattr(settings, setting_name('USER_MODEL'), None) or \
getattr(settings, 'AUTH_USER_MODEL', None) or \
'auth.User'
UID_LENGTH = getattr(settings, setting_name('UID_LENGTH'), 150)
EMAIL_LENGTH = getattr(settings, setting_name('EMAIL_LENGTH'), 254)
NONCE_SERVER_URL_LENGTH = getattr(
settings, setting_name('NONCE_SERVER_URL_LENGTH'), 255)
ASSOCIATION_SERVER_URL_LENGTH = getattr(
settings, setting_name('ASSOCIATION_SERVER_URL_LENGTH'), 255)
ASSOCIATION_HANDLE_LENGTH = getattr(
settings, setting_name('ASSOCIATION_HANDLE_LENGTH'), 255)
class AbstractUserSocialAuth(models.Model, DjangoUserMixin):
"""Abstract Social Auth association model"""
# user = models.ForeignKey(USER_MODEL, related_name='social_auth',
# on_delete=models.CASCADE)
username = LowerCaseCharField(max_length=255, db_index=True)
provider = models.CharField(max_length=32)
uid = models.CharField(max_length=UID_LENGTH)
extra_data = JSONField()
objects = UserSocialAuthManager()
def __str__(self):
return str(self.username)
class Meta:
app_label = "social_django"
abstract = True
@classmethod
def get_social_auth(cls, provider, uid):
try:
social_auth = cls.objects.get(provider=provider, uid=uid)
except cls.DoesNotExist:
return None
try:
u = User.objects.get(email=social_auth.username)
social_auth.user = u
except User.DoesNotExist:
social_auth.user = None
return social_auth
@classmethod
def username_max_length(cls):
return 255
# username_field = cls.username_field()
# field = cls.user_model()._meta.get_field(username_field)
# return field.max_length
@classmethod
def user_model(cls):
return User
# user_model = get_rel_model(field=cls._meta.get_field('user'))
# return user_model
class UserSocialAuth(AbstractUserSocialAuth):
"""Social Auth association model"""
class Meta:
"""Meta data"""
app_label = "social_django"
unique_together = ('provider', 'uid')
db_table = 'social_auth_usersocialauth'
class Nonce(models.Model, DjangoNonceMixin):
"""One use numbers"""
server_url = models.CharField(max_length=NONCE_SERVER_URL_LENGTH)
timestamp = models.IntegerField()
salt = models.CharField(max_length=65)
class Meta:
app_label = "social_django"
unique_together = ('server_url', 'timestamp', 'salt')
db_table = 'social_auth_nonce'
class Association(models.Model, DjangoAssociationMixin):
"""OpenId account association"""
server_url = models.CharField(max_length=ASSOCIATION_SERVER_URL_LENGTH)
handle = models.CharField(max_length=ASSOCIATION_HANDLE_LENGTH)
secret = models.CharField(max_length=255) # Stored base64 encoded
issued = models.IntegerField()
lifetime = models.IntegerField()
assoc_type = models.CharField(max_length=64)
class Meta:
app_label = "social_django"
db_table = 'social_auth_association'
unique_together = (
('server_url', 'handle',)
)
class Code(models.Model, DjangoCodeMixin):
email = models.EmailField(max_length=EMAIL_LENGTH)
code = models.CharField(max_length=32, db_index=True)
verified = models.BooleanField(default=False)
timestamp = models.DateTimeField(auto_now_add=True, db_index=True)
class Meta:
app_label = "social_django"
db_table = 'social_auth_code'
unique_together = ('email', 'code')
class Partial(models.Model, DjangoPartialMixin):
token = models.CharField(max_length=32, db_index=True)
next_step = models.PositiveSmallIntegerField(default=0)
backend = models.CharField(max_length=32)
data = JSONField()
timestamp = models.DateTimeField(auto_now_add=True, db_index=True)
class Meta:
app_label = "social_django"
db_table = 'social_auth_partial'
class DjangoStorage(BaseDjangoStorage):
user = UserSocialAuth
nonce = Nonce
association = Association
code = Code
partial = Partial
@classmethod
def is_integrity_error(cls, exception):
return exception.__class__ is IntegrityError

View File

@@ -0,0 +1,220 @@
"""Django ORM models for Social Auth"""
import base64
import six
import sys
from django.db import transaction
from django.db.utils import IntegrityError
from social_core.storage import UserMixin, AssociationMixin, NonceMixin, \
CodeMixin, PartialMixin, BaseStorage
from seahub.base.accounts import User
class DjangoUserMixin(UserMixin):
"""Social Auth association model"""
@classmethod
def changed(cls, user):
user.save()
def set_extra_data(self, extra_data=None):
if super(DjangoUserMixin, self).set_extra_data(extra_data):
self.save()
@classmethod
def allowed_to_disconnect(cls, user, backend_name, association_id=None):
if association_id is not None:
qs = cls.objects.exclude(id=association_id)
else:
qs = cls.objects.exclude(provider=backend_name)
qs = qs.filter(username=user.username)
if hasattr(user, 'has_usable_password'):
valid_password = user.has_usable_password()
else:
valid_password = True
return valid_password or qs.count() > 0
@classmethod
def disconnect(cls, entry):
entry.delete()
@classmethod
def username_field(cls):
return 'username'
# return getattr(cls.user_model(), 'USERNAME_FIELD', 'username')
@classmethod
def user_exists(cls, *args, **kwargs):
"""
Return True/False if a User instance exists with the given arguments.
Arguments are directly passed to filter() manager method.
"""
if 'username' in kwargs:
kwargs[cls.username_field()] = kwargs.pop('username')
assert 'username' in kwargs
try:
User.objects.get(email=kwargs['username'])
return True
except User.DoesNotExist:
return False
# return cls.user_model().objects.filter(*args, **kwargs).count() > 0
@classmethod
def get_username(cls, user):
return getattr(user, cls.username_field(), None)
@classmethod
def create_user(cls, *args, **kwargs):
username_field = cls.username_field()
if 'username' in kwargs and username_field not in kwargs:
kwargs[username_field] = kwargs.pop('username')
assert 'username' in kwargs
user = User.objects.create_user(email=kwargs['username'],
is_active=True,
save_profile=False)
# try:
# if hasattr(transaction, 'atomic'):
# # In Django versions that have an "atomic" transaction decorator / context
# # manager, there's a transaction wrapped around this call.
# # If the create fails below due to an IntegrityError, ensure that the transaction
# # stays undamaged by wrapping the create in an atomic.
# with transaction.atomic():
# user = cls.user_model().objects.create_user(*args, **kwargs)
# else:
# user = cls.user_model().objects.create_user(*args, **kwargs)
# except IntegrityError:
# # User might have been created on a different thread, try and find them.
# # If we don't, re-raise the IntegrityError.
# exc_info = sys.exc_info()
# # If email comes in as None it won't get found in the get
# if kwargs.get('email', True) is None:
# kwargs['email'] = ''
# try:
# user = cls.user_model().objects.get(*args, **kwargs)
# except cls.user_model().DoesNotExist:
# six.reraise(*exc_info)
return user
@classmethod
def get_user(cls, pk=None, **kwargs):
if pk:
kwargs = {'pk': pk}
try:
return User.objects.get(email=pk)
except User.DoesNotExist:
return None
# try:
# return cls.user_model().objects.get(**kwargs)
# except cls.user_model().DoesNotExist:
# return None
@classmethod
def get_users_by_email(cls, email):
user_model = cls.user_model()
email_field = getattr(user_model, 'EMAIL_FIELD', 'email')
return user_model.objects.filter(**{email_field + '__iexact': email})
@classmethod
def get_social_auth(cls, provider, uid):
if not isinstance(uid, six.string_types):
uid = str(uid)
try:
return cls.objects.get(provider=provider, uid=uid)
except cls.DoesNotExist:
return None
@classmethod
def get_social_auth_for_user(cls, user, provider=None, id=None):
qs = cls.objects.filter(username=user.username)
if provider:
qs = qs.filter(provider=provider)
if id:
qs = qs.filter(id=id)
return qs
@classmethod
def create_social_auth(cls, user, uid, provider):
if not isinstance(uid, six.string_types):
uid = str(uid)
if hasattr(transaction, 'atomic'):
# In Django versions that have an "atomic" transaction decorator / context
# manager, there's a transaction wrapped around this call.
# If the create fails below due to an IntegrityError, ensure that the transaction
# stays undamaged by wrapping the create in an atomic.
with transaction.atomic():
social_auth = cls.objects.create(username=user.username, uid=uid, provider=provider)
else:
social_auth = cls.objects.create(username=user.username, uid=uid, provider=provider)
return social_auth
class DjangoNonceMixin(NonceMixin):
@classmethod
def use(cls, server_url, timestamp, salt):
return cls.objects.get_or_create(server_url=server_url,
timestamp=timestamp,
salt=salt)[1]
class DjangoAssociationMixin(AssociationMixin):
@classmethod
def store(cls, server_url, association):
# Don't use get_or_create because issued cannot be null
try:
assoc = cls.objects.get(server_url=server_url,
handle=association.handle)
except cls.DoesNotExist:
assoc = cls(server_url=server_url,
handle=association.handle)
assoc.secret = base64.encodestring(association.secret)
assoc.issued = association.issued
assoc.lifetime = association.lifetime
assoc.assoc_type = association.assoc_type
assoc.save()
@classmethod
def get(cls, *args, **kwargs):
return cls.objects.filter(*args, **kwargs)
@classmethod
def remove(cls, ids_to_delete):
cls.objects.filter(pk__in=ids_to_delete).delete()
class DjangoCodeMixin(CodeMixin):
@classmethod
def get_code(cls, code):
try:
return cls.objects.get(code=code)
except cls.DoesNotExist:
return None
class DjangoPartialMixin(PartialMixin):
@classmethod
def load(cls, token):
try:
return cls.objects.get(token=token)
except cls.DoesNotExist:
return None
@classmethod
def destroy(cls, token):
partial = cls.load(token)
if partial:
partial.delete()
class BaseDjangoStorage(BaseStorage):
user = DjangoUserMixin
nonce = DjangoNonceMixin
association = DjangoAssociationMixin
code = DjangoCodeMixin

View File

@@ -0,0 +1,159 @@
# coding=utf-8
from django.conf import settings
from django.http import HttpResponse, HttpRequest
from django.db.models import Model
from django.contrib.contenttypes.models import ContentType
from django.contrib.auth import authenticate
from django.shortcuts import redirect, resolve_url
from django.template import TemplateDoesNotExist, loader, engines
from django.utils.crypto import get_random_string
from django.utils.encoding import force_text
from django.utils.functional import Promise
from django.utils.translation import get_language
from social_core.strategy import BaseStrategy, BaseTemplateStrategy
from .compat import get_request_port
def render_template_string(request, html, context=None):
"""Take a template in the form of a string and render it for the
given context"""
template = engines['django'].from_string(html)
return template.render(context=context, request=request)
class DjangoTemplateStrategy(BaseTemplateStrategy):
def render_template(self, tpl, context):
template = loader.get_template(tpl)
return template.render(context=context, request=self.strategy.request)
def render_string(self, html, context):
return render_template_string(self.strategy.request, html, context)
class DjangoStrategy(BaseStrategy):
DEFAULT_TEMPLATE_STRATEGY = DjangoTemplateStrategy
def __init__(self, storage, request=None, tpl=None):
self.request = request
self.session = request.session if request else {}
super(DjangoStrategy, self).__init__(storage, tpl)
def get_setting(self, name):
value = getattr(settings, name)
# Force text on URL named settings that are instance of Promise
if name.endswith('_URL'):
if isinstance(value, Promise):
value = force_text(value)
value = resolve_url(value)
return value
def request_data(self, merge=True):
if not self.request:
return {}
if merge:
data = self.request.GET.copy()
data.update(self.request.POST)
elif self.request.method == 'POST':
data = self.request.POST
else:
data = self.request.GET
return data
def request_host(self):
if self.request:
return self.request.get_host()
def request_is_secure(self):
"""Is the request using HTTPS?"""
return self.request.is_secure()
def request_path(self):
"""path of the current request"""
return self.request.path
def request_port(self):
"""Port in use for this request"""
return get_request_port(request=self.request)
def request_get(self):
"""Request GET data"""
return self.request.GET.copy()
def request_post(self):
"""Request POST data"""
return self.request.POST.copy()
def redirect(self, url):
return redirect(url)
def html(self, content):
return HttpResponse(content, content_type='text/html;charset=UTF-8')
def render_html(self, tpl=None, html=None, context=None):
if not tpl and not html:
raise ValueError('Missing template or html parameters')
context = context or {}
try:
template = loader.get_template(tpl)
return template.render(context=context, request=self.request)
except TemplateDoesNotExist:
return render_template_string(self.request, html, context)
def authenticate(self, backend, *args, **kwargs):
kwargs['strategy'] = self
kwargs['storage'] = self.storage
kwargs['backend'] = backend
return authenticate(*args, **kwargs)
def clean_authenticate_args(self, *args, **kwargs):
"""Cleanup request argument if present, which is passed to
authenticate as for Django 1.11"""
if len(args) > 0 and isinstance(args[0], HttpRequest):
kwargs['request'], args = args[0], args[1:]
return args, kwargs
def session_get(self, name, default=None):
return self.session.get(name, default)
def session_set(self, name, value):
self.session[name] = value
if hasattr(self.session, 'modified'):
self.session.modified = True
def session_pop(self, name):
return self.session.pop(name, None)
def session_setdefault(self, name, value):
return self.session.setdefault(name, value)
def build_absolute_uri(self, path=None):
if self.request:
return self.request.build_absolute_uri(path)
else:
return path
def random_string(self, length=12, chars=BaseStrategy.ALLOWED_CHARS):
return get_random_string(length, chars)
def to_session_value(self, val):
"""Converts values that are instance of Model to a dictionary
with enough information to retrieve the instance back later."""
if isinstance(val, Model):
val = {
'pk': val.pk,
'ctype': ContentType.objects.get_for_model(val).pk
}
return val
def from_session_value(self, val):
"""Converts back the instance saved by self._ctype function."""
if isinstance(val, dict) and 'pk' in val and 'ctype' in val:
ctype = ContentType.objects.get_for_id(val['ctype'])
ModelClass = ctype.model_class()
val = ModelClass.objects.get(pk=val['pk'])
return val
def get_language(self):
"""Return current language"""
return get_language()

View File

@@ -0,0 +1,24 @@
"""URLs module"""
from django.conf import settings
from django.conf.urls import url
from social_core.utils import setting_name
from . import views
extra = getattr(settings, setting_name('TRAILING_SLASH'), True) and '/' or ''
app_name = 'social'
urlpatterns = [
# authentication / association
url(r'^login/(?P<backend>[^/]+){0}$'.format(extra), views.auth,
name='begin'),
url(r'^complete/(?P<backend>[^/]+){0}$'.format(extra), views.complete,
name='complete'),
# disconnection
url(r'^disconnect/(?P<backend>[^/]+){0}$'.format(extra), views.disconnect,
name='disconnect'),
url(r'^disconnect/(?P<backend>[^/]+)/(?P<association_id>\d+){0}$'
.format(extra), views.disconnect, name='disconnect_individual'),
]

View File

@@ -0,0 +1,51 @@
# coding=utf-8
from functools import wraps
from django.conf import settings
from django.http import Http404
from social_core.utils import setting_name, module_member, get_strategy
from social_core.exceptions import MissingBackend
from social_core.backends.utils import get_backend
from .compat import reverse
BACKENDS = settings.AUTHENTICATION_BACKENDS
STRATEGY = getattr(settings, setting_name('STRATEGY'),
'social_django.strategy.DjangoStrategy')
STORAGE = getattr(settings, setting_name('STORAGE'),
'social_django.models.DjangoStorage')
Strategy = module_member(STRATEGY)
Storage = module_member(STORAGE)
def load_strategy(request=None):
return get_strategy(STRATEGY, STORAGE, request)
def load_backend(strategy, name, redirect_uri):
Backend = get_backend(BACKENDS, name)
return Backend(strategy, redirect_uri)
def psa(redirect_uri=None, load_strategy=load_strategy):
def decorator(func):
@wraps(func)
def wrapper(request, backend, *args, **kwargs):
uri = redirect_uri
if uri and not uri.startswith('/'):
uri = reverse(redirect_uri, args=(backend,))
request.social_strategy = load_strategy(request)
# backward compatibility in attribute name, only if not already
# defined
if not hasattr(request, 'strategy'):
request.strategy = request.social_strategy
try:
request.backend = load_backend(request.social_strategy,
backend, uri)
except MissingBackend:
raise Http404('Backend not found')
return func(request, backend, *args, **kwargs)
return wrapper
return decorator

View File

@@ -0,0 +1,131 @@
from django.conf import settings
from django.contrib.auth import REDIRECT_FIELD_NAME
from django.contrib.auth.decorators import login_required
from django.views.decorators.csrf import csrf_exempt, csrf_protect
from django.views.decorators.http import require_POST
from django.views.decorators.cache import never_cache
from seahub.auth import login
from social_core.utils import setting_name
from social_core.actions import do_auth, do_complete, do_disconnect
from .utils import psa
NAMESPACE = getattr(settings, setting_name('URL_NAMESPACE'), None) or 'social'
# Calling `session.set_expiry(None)` results in a session lifetime equal to
# platform default session lifetime.
DEFAULT_SESSION_TIMEOUT = None
@never_cache
@psa('{0}:complete'.format(NAMESPACE))
def auth(request, backend):
return do_auth(request.backend, redirect_name=REDIRECT_FIELD_NAME)
@never_cache
@csrf_exempt
@psa('{0}:complete'.format(NAMESPACE))
def complete(request, backend, *args, **kwargs):
"""Authentication complete view"""
return do_complete(request.backend, _do_login, request.user,
redirect_name=REDIRECT_FIELD_NAME, request=request,
*args, **kwargs)
@never_cache
@login_required
@psa()
@require_POST
@csrf_protect
def disconnect(request, backend, association_id=None):
"""Disconnects given backend from current logged in user."""
return do_disconnect(request.backend, request.user, association_id,
redirect_name=REDIRECT_FIELD_NAME)
def get_session_timeout(social_user, enable_session_expiration=False,
max_session_length=None):
if enable_session_expiration:
# Retrieve an expiration date from the social user who just finished
# logging in; this value was set by the social auth backend, and was
# typically received from the server.
expiration = social_user.expiration_datetime()
# We've enabled session expiration. Check to see if we got
# a specific expiration time from the provider for this user;
# if not, use the platform default expiration.
if expiration:
received_expiration_time = expiration.total_seconds()
else:
received_expiration_time = DEFAULT_SESSION_TIMEOUT
# Check to see if the backend set a value as a maximum length
# that a session may be; if they did, then we should use the minimum
# of that and the received session expiration time, if any, to
# set the session length.
if received_expiration_time is None and max_session_length is None:
# We neither received an expiration length, nor have a maximum
# session length. Use the platform default.
session_expiry = DEFAULT_SESSION_TIMEOUT
elif received_expiration_time is None and max_session_length is not None:
# We only have a maximum session length; use that.
session_expiry = max_session_length
elif received_expiration_time is not None and max_session_length is None:
# We only have an expiration time received by the backend
# from the provider, with no set maximum. Use that.
session_expiry = received_expiration_time
else:
# We received an expiration time from the backend, and we also
# have a set maximum session length. Use the smaller of the two.
session_expiry = min(received_expiration_time, max_session_length)
else:
# If there's an explicitly-set maximum session length, use that
# even if we don't want to retrieve session expiry times from
# the backend. If there isn't, then use the platform default.
if max_session_length is None:
session_expiry = DEFAULT_SESSION_TIMEOUT
else:
session_expiry = max_session_length
return session_expiry
def _do_login(backend, user, social_user):
user.backend = '{0}.{1}'.format(backend.__module__,
backend.__class__.__name__)
# Get these details early to avoid any issues involved in the
# session switch that happens when we call login().
enable_session_expiration = backend.setting('SESSION_EXPIRATION', False)
max_session_length_setting = backend.setting('MAX_SESSION_LENGTH', None)
# Log the user in, creating a new session.
login(backend.strategy.request, user)
# Make sure that the max_session_length value is either an integer or
# None. Because we get this as a setting from the backend, it can be set
# to whatever the backend creator wants; we want to be resilient against
# unexpected types being presented to us.
try:
max_session_length = int(max_session_length_setting)
except (TypeError, ValueError):
# We got a response that doesn't look like a number; use the default.
max_session_length = None
# Get the session expiration length based on the maximum session length
# setting, combined with any session length received from the backend.
session_expiry = get_session_timeout(
social_user,
enable_session_expiration=enable_session_expiration,
max_session_length=max_session_length,
)
try:
# Set the session length to our previously determined expiry length.
backend.strategy.request.session.set_expiry(session_expiry)
except OverflowError:
# The timestamp we used wasn't in the range of values supported by
# Django for session length; use the platform default. We tried.
backend.strategy.request.session.set_expiry(DEFAULT_SESSION_TIMEOUT)

View File

@@ -0,0 +1,144 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
##
# Copyright (C) 2018 All rights reserved.
#
# @File AbstractApi.py
# @Brief
# @Author abelzhu, abelzhu@tencent.com
# @Version 1.0
# @Date 2018-02-24
#
#
import sys
import os
import re
import json
import requests
DEBUG = False
class ApiException(Exception) :
def __init__(self, errCode, errMsg) :
self.errCode = errCode
self.errMsg = errMsg
class AbstractApi(object) :
def __init__(self) :
return
def getAccessToken(self) :
raise NotImplementedError
def refreshAccessToken(self) :
raise NotImplementedError
def getSuiteAccessToken(self) :
raise NotImplementedError
def refreshSuiteAccessToken(self) :
raise NotImplementedError
def getProviderAccessToken(self) :
raise NotImplementedError
def refreshProviderAccessToken(self) :
raise NotImplementedError
def httpCall(self, urlType, args=None) :
shortUrl = urlType[0]
method = urlType[1]
response = {}
for retryCnt in range(0, 3) :
if 'POST' == method :
url = self.__makeUrl(shortUrl)
response = self.__httpPost(url, args)
elif 'GET' == method :
url = self.__makeUrl(shortUrl)
url = self.__appendArgs(url, args)
response = self.__httpGet(url)
else :
raise ApiException(-1, "unknown method type")
# check if token expired
if self.__tokenExpired(response.get('errcode')) :
self.__refreshToken(shortUrl)
retryCnt += 1
continue
else :
break
return self.__checkResponse(response)
@staticmethod
def __appendArgs(url, args) :
if args is None :
return url
for key, value in args.items() :
if '?' in url :
url += ('&' + key + '=' + value)
else :
url += ('?' + key + '=' + value)
return url
@staticmethod
def __makeUrl(shortUrl) :
base = "https://qyapi.weixin.qq.com"
if shortUrl[0] == '/' :
return base + shortUrl
else :
return base + '/' + shortUrl
def __appendToken(self, url) :
if 'SUITE_ACCESS_TOKEN' in url :
return url.replace('SUITE_ACCESS_TOKEN', self.getSuiteAccessToken())
elif 'PROVIDER_ACCESS_TOKEN' in url :
return url.replace('PROVIDER_ACCESS_TOKEN', self.getProviderAccessToken())
elif 'ACCESS_TOKEN' in url :
return url.replace('ACCESS_TOKEN', self.getAccessToken())
else :
return url
def __httpPost(self, url, args) :
realUrl = self.__appendToken(url)
if DEBUG is True :
print realUrl, args
return requests.post(realUrl, data = json.dumps(args, ensure_ascii = False).encode('utf-8')).json()
def __httpGet(self, url) :
realUrl = self.__appendToken(url)
if DEBUG is True :
print realUrl
return requests.get(realUrl).json()
def __post_file(self, url, media_file):
return requests.post(url, file=media_file).json()
@staticmethod
def __checkResponse(response):
errCode = response.get('errcode')
errMsg = response.get('errmsg')
if errCode is 0:
return response
else:
raise ApiException(errCode, errMsg)
@staticmethod
def __tokenExpired(errCode) :
if errCode == 40014 or errCode == 42001 or errCode == 42007 or errCode == 42009 :
return True
else :
return False
def __refreshToken(self, url) :
if 'SUITE_ACCESS_TOKEN' in url :
self.refreshSuiteAccessToken()
elif 'PROVIDER_ACCESS_TOKEN' in url :
self.refreshProviderAccessToken()
elif 'ACCESS_TOKEN' in url :
self.refreshAccessToken()

View File

@@ -0,0 +1,104 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
##
# Copyright (C) 2018 All rights reserved.
#
# @File CorpApi.py
# @Brief
# @Author abelzhu, abelzhu@tencent.com
# @Version 1.0
# @Date 2018-02-24
#
#
from .AbstractApi import *
CORP_API_TYPE = {
'GET_ACCESS_TOKEN' : ['/cgi-bin/gettoken', 'GET'],
'USER_CREATE' : ['/cgi-bin/user/create?access_token=ACCESS_TOKEN', 'POST'],
'USER_GET' : ['/cgi-bin/user/get?access_token=ACCESS_TOKEN', 'GET'],
'USER_UPDATE' : ['/cgi-bin/user/update?access_token=ACCESS_TOKEN', 'POST'],
'USER_DELETE' : ['/cgi-bin/user/delete?access_token=ACCESS_TOKEN', 'GET'],
'USER_BATCH_DELETE': ['/cgi-bin/user/batchdelete?access_token=ACCESS_TOKEN', 'POST'],
'USER_SIMPLE_LIST ': ['/cgi-bin/user/simplelist?access_token=ACCESS_TOKEN', 'GET'],
'USER_LIST' : ['/cgi-bin/user/list?access_token=ACCESS_TOKEN', 'GET'],
'USERID_TO_OPENID' : ['/cgi-bin/user/convert_to_openid?access_token=ACCESS_TOKEN', 'POST'],
'OPENID_TO_USERID' : ['/cgi-bin/user/convert_to_userid?access_token=ACCESS_TOKEN', 'POST'],
'USER_AUTH_SUCCESS': ['/cgi-bin/user/authsucc?access_token=ACCESS_TOKEN', 'GET'],
'DEPARTMENT_CREATE': ['/cgi-bin/department/create?access_token=ACCESS_TOKEN', 'POST'],
'DEPARTMENT_UPDATE': ['/cgi-bin/department/update?access_token=ACCESS_TOKEN', 'POST'],
'DEPARTMENT_DELETE': ['/cgi-bin/department/delete?access_token=ACCESS_TOKEN', 'GET'],
'DEPARTMENT_LIST' : ['/cgi-bin/department/list?access_token=ACCESS_TOKEN', 'GET'],
'TAG_CREATE' : ['/cgi-bin/tag/create?access_token=ACCESS_TOKEN', 'POST'],
'TAG_UPDATE' : ['/cgi-bin/tag/update?access_token=ACCESS_TOKEN', 'POST'],
'TAG_DELETE' : ['/cgi-bin/tag/delete?access_token=ACCESS_TOKEN', 'GET'],
'TAG_GET_USER' : ['/cgi-bin/tag/get?access_token=ACCESS_TOKEN', 'GET'],
'TAG_ADD_USER' : ['/cgi-bin/tag/addtagusers?access_token=ACCESS_TOKEN', 'POST'],
'TAG_DELETE_USER' : ['/cgi-bin/tag/deltagusers?access_token=ACCESS_TOKEN', 'POST'],
'TAG_GET_LIST' : ['/cgi-bin/tag/list?access_token=ACCESS_TOKEN', 'GET'],
'BATCH_JOB_GET_RESULT' : ['/cgi-bin/batch/getresult?access_token=ACCESS_TOKEN', 'GET'],
'BATCH_INVITE' : ['/cgi-bin/batch/invite?access_token=ACCESS_TOKEN', 'POST'],
'AGENT_GET' : ['/cgi-bin/agent/get?access_token=ACCESS_TOKEN', 'GET'],
'AGENT_SET' : ['/cgi-bin/agent/set?access_token=ACCESS_TOKEN', 'POST'],
'AGENT_GET_LIST' : ['/cgi-bin/agent/list?access_token=ACCESS_TOKEN', 'GET'],
'MENU_CREATE' : ['/cgi-bin/menu/create?access_token=ACCESS_TOKEN', 'POST'], ## TODO
'MENU_GET' : ['/cgi-bin/menu/get?access_token=ACCESS_TOKEN', 'GET'],
'MENU_DELETE' : ['/cgi-bin/menu/delete?access_token=ACCESS_TOKEN', 'GET'],
'MESSAGE_SEND' : ['/cgi-bin/message/send?access_token=ACCESS_TOKEN', 'POST'],
'MESSAGE_REVOKE' : ['/cgi-bin/message/revoke?access_token=ACCESS_TOKEN', 'POST'],
'MEDIA_GET' : ['/cgi-bin/media/get?access_token=ACCESS_TOKEN', 'GET'],
'GET_USER_INFO_BY_CODE' : ['/cgi-bin/user/getuserinfo?access_token=ACCESS_TOKEN', 'GET'],
'GET_USER_DETAIL' : ['/cgi-bin/user/getuserdetail?access_token=ACCESS_TOKEN', 'POST'],
'GET_TICKET' : ['/cgi-bin/ticket/get?access_token=ACCESS_TOKEN', 'GET'],
'GET_JSAPI_TICKET' : ['/cgi-bin/get_jsapi_ticket?access_token=ACCESS_TOKEN', 'GET'],
'GET_CHECKIN_OPTION' : ['/cgi-bin/checkin/getcheckinoption?access_token=ACCESS_TOKEN', 'POST'],
'GET_CHECKIN_DATA' : ['/cgi-bin/checkin/getcheckindata?access_token=ACCESS_TOKEN', 'POST'],
'GET_APPROVAL_DATA': ['/cgi-bin/corp/getapprovaldata?access_token=ACCESS_TOKEN', 'POST'],
'GET_INVOICE_INFO' : ['/cgi-bin/card/invoice/reimburse/getinvoiceinfo?access_token=ACCESS_TOKEN', 'POST'],
'UPDATE_INVOICE_STATUS' :
['/cgi-bin/card/invoice/reimburse/updateinvoicestatus?access_token=ACCESS_TOKEN', 'POST'],
'BATCH_UPDATE_INVOICE_STATUS' :
['/cgi-bin/card/invoice/reimburse/updatestatusbatch?access_token=ACCESS_TOKEN', 'POST'],
'BATCH_GET_INVOICE_INFO' :
['/cgi-bin/card/invoice/reimburse/getinvoiceinfobatch?access_token=ACCESS_TOKEN', 'POST'],
'APP_CHAT_CREATE' : ['/cgi-bin/appchat/create?access_token=ACCESS_TOKEN', 'POST'],
'APP_CHAT_GET' : ['/cgi-bin/appchat/get?access_token=ACCESS_TOKEN', 'GET'],
'APP_CHAT_UPDATE' : ['/cgi-bin/appchat/update?access_token=ACCESS_TOKEN', 'POST'],
'APP_CHAT_SEND' : ['/cgi-bin/appchat/send?access_token=ACCESS_TOKEN', 'POST'],
'MINIPROGRAM_CODE_TO_SESSION_KEY' : ['/cgi-bin/miniprogram/jscode2session?access_token=ACCESS_TOKEN', 'GET'],
}
class CorpApi(AbstractApi) :
def __init__(self, corpid, secret) :
self.corpid = corpid
self.secret = secret
self.access_token = None
def getAccessToken(self) :
if self.access_token is None :
self.refreshAccessToken()
return self.access_token
def refreshAccessToken(self) :
response = self.httpCall(
CORP_API_TYPE['GET_ACCESS_TOKEN'],
{
'corpid' : self.corpid,
'corpsecret': self.secret,
})
self.access_token = response.get('access_token')

View File

@@ -0,0 +1 @@
"""ref: https://github.com/sbzhu/weworkapi_python"""