1
0
mirror of https://github.com/haiwen/seahub.git synced 2025-10-21 02:42:26 +00:00

[api2] Return 401 if api token is invalid

This commit is contained in:
zhengxie
2015-04-13 11:59:57 +08:00
parent 16b77837b8
commit 94dcfe338a
2 changed files with 38 additions and 11 deletions

View File

@@ -1,6 +1,8 @@
import datetime import datetime
import logging import logging
from rest_framework import status
from rest_framework.authentication import BaseAuthentication from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import APIException
import seaserv import seaserv
from seahub.base.accounts import User from seahub.base.accounts import User
@@ -25,6 +27,14 @@ def within_ten_min(d1, d2):
HEADER_CLIENT_VERSION = 'HTTP_SEAFILE_CLEINT_VERSION' HEADER_CLIENT_VERSION = 'HTTP_SEAFILE_CLEINT_VERSION'
HEADER_PLATFORM_VERSION = 'HTTP_SEAFILE_PLATFORM_VERSION' HEADER_PLATFORM_VERSION = 'HTTP_SEAFILE_PLATFORM_VERSION'
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Incorrect authentication credentials.'
def __init__(self, detail=None):
self.detail = detail or self.default_detail
class TokenAuthentication(BaseAuthentication): class TokenAuthentication(BaseAuthentication):
""" """
Simple token based authentication. Simple token based authentication.
@@ -42,13 +52,17 @@ class TokenAuthentication(BaseAuthentication):
def authenticate(self, request): def authenticate(self, request):
auth = request.META.get('HTTP_AUTHORIZATION', '').split() auth = request.META.get('HTTP_AUTHORIZATION', '').split()
key = None if not auth or auth[0].lower() != 'token':
if len(auth) == 2 and auth[0].lower() == "token":
key = auth[1]
if not key:
return None return None
if len(auth) == 1:
msg = 'Invalid token header. No credentials provided.'
raise AuthenticationFailed(msg)
elif len(auth) > 2:
msg = 'Invalid token header. Token string should not contain spaces.'
raise AuthenticationFailed(msg)
key = auth[1]
ret = self.authenticate_v2(request, key) ret = self.authenticate_v2(request, key)
if ret: if ret:
return ret return ret
@@ -67,12 +81,12 @@ class TokenAuthentication(BaseAuthentication):
try: try:
token = Token.objects.get(key=key) token = Token.objects.get(key=key)
except Token.DoesNotExist: except Token.DoesNotExist:
return None raise AuthenticationFailed('Invalid token')
try: try:
user = User.objects.get(email=token.user) user = User.objects.get(email=token.user)
except User.DoesNotExist: except User.DoesNotExist:
return None raise AuthenticationFailed('User inactive or deleted')
if MULTI_TENANCY: if MULTI_TENANCY:
orgs = seaserv.get_orgs_by_user(token.user) orgs = seaserv.get_orgs_by_user(token.user)
@@ -88,12 +102,12 @@ class TokenAuthentication(BaseAuthentication):
try: try:
token = TokenV2.objects.get(key=key) token = TokenV2.objects.get(key=key)
except TokenV2.DoesNotExist: except TokenV2.DoesNotExist:
return None raise AuthenticationFailed('Invalid token')
try: try:
user = User.objects.get(email=token.user) user = User.objects.get(email=token.user)
except User.DoesNotExist: except User.DoesNotExist:
return None raise AuthenticationFailed('User inactive or deleted')
if MULTI_TENANCY: if MULTI_TENANCY:
orgs = seaserv.get_orgs_by_user(token.user) orgs = seaserv.get_orgs_by_user(token.user)

View File

@@ -18,9 +18,22 @@ def fake_ccnet_id():
return randstring(length=40) return randstring(length=40)
class AuthTest(ApiTestBase): class AuthTest(ApiTestBase):
"""This tests involves creating/deleting api tokens, so for this test we use """This tests involves creating/deleting api tokens, so for this test we
a specific auth token so that it won't affect other test cases. use a specific auth token so that it won't affect other test cases.
""" """
def test_auth_token_missing(self):
return self.get(AUTH_PING_URL, token=None, use_token=False,
expected=403)
def test_auth_token_is_empty(self):
return self.get(AUTH_PING_URL, token='', expected=401)
def test_auth_token_contains_space(self):
return self.get(AUTH_PING_URL, token='token with space', expected=401)
def test_random_auth_token(self):
return self.get(AUTH_PING_URL, token='randomtoken', expected=401)
def test_logout_device(self): def test_logout_device(self):
token = self._desktop_login() token = self._desktop_login()
self._do_auth_ping(token, expected=200) self._do_auth_ping(token, expected=200)