1
0
mirror of https://github.com/haiwen/seahub.git synced 2025-09-02 07:27:04 +00:00

add djangorestframework

This commit is contained in:
poet
2012-07-14 14:25:17 +08:00
parent 823b437091
commit f29db5b649
56 changed files with 8249 additions and 0 deletions

View File

@@ -81,6 +81,7 @@ TEMPLATE_DIRS = (
# Always use forward slashes, even on Windows. # Always use forward slashes, even on Windows.
# Don't forget to use absolute paths, not relative paths. # Don't forget to use absolute paths, not relative paths.
os.path.join(os.path.dirname(__file__), "templates"), os.path.join(os.path.dirname(__file__), "templates"),
os.path.join(os.path.dirname(__file__),'thirdpart/djangorestframework/templates'),
) )
TEMPLATE_CONTEXT_PROCESSORS = ( TEMPLATE_CONTEXT_PROCESSORS = (
@@ -111,6 +112,7 @@ INSTALLED_APPS = (
'seahub.group', 'seahub.group',
'seahub.share', 'seahub.share',
'seahub.subdomain', 'seahub.subdomain',
'seahub.api',
) )
AUTHENTICATION_BACKENDS = ( AUTHENTICATION_BACKENDS = (

View File

@@ -0,0 +1,3 @@
__version__ = '0.3.3'
VERSION = __version__ # synonym

View File

@@ -0,0 +1,101 @@
"""
The :mod:`authentication` module provides a set of pluggable authentication classes.
Authentication behavior is provided by mixing the :class:`mixins.AuthMixin` class into a :class:`View` class.
The set of authentication methods which are used is then specified by setting the
:attr:`authentication` attribute on the :class:`View` class, and listing a set of :class:`authentication` classes.
"""
from django.contrib.auth import authenticate
from djangorestframework.compat import CsrfViewMiddleware
import base64
__all__ = (
'BaseAuthentication',
'BasicAuthentication',
'UserLoggedInAuthentication'
)
class BaseAuthentication(object):
"""
All authentication classes should extend BaseAuthentication.
"""
def __init__(self, view):
"""
:class:`Authentication` classes are always passed the current view on creation.
"""
self.view = view
def authenticate(self, request):
"""
Authenticate the :obj:`request` and return a :obj:`User` or :const:`None`. [*]_
.. [*] The authentication context *will* typically be a :obj:`User`,
but it need not be. It can be any user-like object so long as the
permissions classes (see the :mod:`permissions` module) on the view can
handle the object and use it to determine if the request has the required
permissions or not.
This can be an important distinction if you're implementing some token
based authentication mechanism, where the authentication context
may be more involved than simply mapping to a :obj:`User`.
"""
return None
class BasicAuthentication(BaseAuthentication):
"""
Use HTTP Basic authentication.
"""
def authenticate(self, request):
"""
Returns a :obj:`User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns :const:`None`.
"""
from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError
if 'HTTP_AUTHORIZATION' in request.META:
auth = request.META['HTTP_AUTHORIZATION'].split()
if len(auth) == 2 and auth[0].lower() == "basic":
try:
auth_parts = base64.b64decode(auth[1]).partition(':')
except TypeError:
return None
try:
uname, passwd = smart_unicode(auth_parts[0]), smart_unicode(auth_parts[2])
except DjangoUnicodeDecodeError:
return None
user = authenticate(username=uname, password=passwd)
if user is not None and user.is_active:
return user
return None
class UserLoggedInAuthentication(BaseAuthentication):
"""
Use Django's session framework for authentication.
"""
def authenticate(self, request):
"""
Returns a :obj:`User` if the request session currently has a logged in user.
Otherwise returns :const:`None`.
"""
self.view.DATA # Make sure our generic parsing runs first
if getattr(request, 'user', None) and request.user.is_active:
# Enforce CSRF validation for session based authentication.
resp = CsrfViewMiddleware().process_view(request, None, (), {})
if resp is None: # csrf passed
return request.user
return None
# TODO: TokenAuthentication, DigestAuthentication, OAuthAuthentication

View File

@@ -0,0 +1,459 @@
"""
The :mod:`compat` module provides support for backwards compatibility with older versions of django/python.
"""
import django
# cStringIO only if it's available, otherwise StringIO
try:
import cStringIO as StringIO
except ImportError:
import StringIO
# parse_qs from 'urlparse' module unless python 2.5, in which case from 'cgi'
try:
# python >= 2.6
from urlparse import parse_qs
except ImportError:
# python < 2.6
from cgi import parse_qs
# django.test.client.RequestFactory (Required for Django < 1.3)
try:
from django.test.client import RequestFactory
except ImportError:
from django.test import Client
from django.core.handlers.wsgi import WSGIRequest
# From: http://djangosnippets.org/snippets/963/
# Lovely stuff
class RequestFactory(Client):
"""
Class that lets you create mock :obj:`Request` objects for use in testing.
Usage::
rf = RequestFactory()
get_request = rf.get('/hello/')
post_request = rf.post('/submit/', {'foo': 'bar'})
This class re-uses the :class:`django.test.client.Client` interface. Of which
you can find the docs here__.
__ http://www.djangoproject.com/documentation/testing/#the-test-client
Once you have a `request` object you can pass it to any :func:`view` function,
just as if that :func:`view` had been hooked up using a URLconf.
"""
def request(self, **request):
"""
Similar to parent class, but returns the :obj:`request` object as soon as it
has created it.
"""
environ = {
'HTTP_COOKIE': self.cookies,
'PATH_INFO': '/',
'QUERY_STRING': '',
'REQUEST_METHOD': 'GET',
'SCRIPT_NAME': '',
'SERVER_NAME': 'testserver',
'SERVER_PORT': 80,
'SERVER_PROTOCOL': 'HTTP/1.1',
}
environ.update(self.defaults)
environ.update(request)
return WSGIRequest(environ)
# django.views.generic.View (Django >= 1.3)
try:
from django.views.generic import View
if not hasattr(View, 'head'):
# First implementation of Django class-based views did not include head method
# in base View class - https://code.djangoproject.com/ticket/15668
class ViewPlusHead(View):
def head(self, request, *args, **kwargs):
return self.get(request, *args, **kwargs)
View = ViewPlusHead
except ImportError:
from django import http
from django.utils.functional import update_wrapper
# from django.utils.log import getLogger
# from django.utils.decorators import classonlymethod
# logger = getLogger('django.request') - We'll just drop support for logger if running Django <= 1.2
# Might be nice to fix this up sometime to allow djangorestframework.compat.View to match 1.3's View more closely
class View(object):
"""
Intentionally simple parent class for all views. Only implements
dispatch-by-method and simple sanity checking.
"""
http_method_names = ['get', 'post', 'put', 'delete', 'head', 'options', 'trace']
def __init__(self, **kwargs):
"""
Constructor. Called in the URLconf; can contain helpful extra
keyword arguments, and other things.
"""
# Go through keyword arguments, and either save their values to our
# instance, or raise an error.
for key, value in kwargs.iteritems():
setattr(self, key, value)
# @classonlymethod - We'll just us classmethod instead if running Django <= 1.2
@classmethod
def as_view(cls, **initkwargs):
"""
Main entry point for a request-response process.
"""
# sanitize keyword arguments
for key in initkwargs:
if key in cls.http_method_names:
raise TypeError(u"You tried to pass in the %s method name as a "
u"keyword argument to %s(). Don't do that."
% (key, cls.__name__))
if not hasattr(cls, key):
raise TypeError(u"%s() received an invalid keyword %r" % (
cls.__name__, key))
def view(request, *args, **kwargs):
self = cls(**initkwargs)
return self.dispatch(request, *args, **kwargs)
# take name and docstring from class
update_wrapper(view, cls, updated=())
# and possible attributes set by decorators
# like csrf_exempt from dispatch
update_wrapper(view, cls.dispatch, assigned=())
return view
def dispatch(self, request, *args, **kwargs):
# Try to dispatch to the right method; if a method doesn't exist,
# defer to the error handler. Also defer to the error handler if the
# request method isn't on the approved list.
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(), self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
self.request = request
self.args = args
self.kwargs = kwargs
return handler(request, *args, **kwargs)
def http_method_not_allowed(self, request, *args, **kwargs):
allowed_methods = [m for m in self.http_method_names if hasattr(self, m)]
#logger.warning('Method Not Allowed (%s): %s' % (request.method, request.path),
# extra={
# 'status_code': 405,
# 'request': self.request
# }
#)
return http.HttpResponseNotAllowed(allowed_methods)
def head(self, request, *args, **kwargs):
return self.get(request, *args, **kwargs)
# PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
if django.VERSION >= (1, 4):
from django.middleware.csrf import CsrfViewMiddleware
else:
import hashlib
import re
import random
import logging
import urlparse
from django.conf import settings
from django.core.urlresolvers import get_callable
try:
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
logger = logging.getLogger('django.request')
if not logger.handlers:
logger.addHandler(NullHandler())
def same_origin(url1, url2):
"""
Checks if two URLs are 'same-origin'
"""
p1, p2 = urlparse.urlparse(url1), urlparse.urlparse(url2)
return p1[0:2] == p2[0:2]
def constant_time_compare(val1, val2):
"""
Returns True if the two strings are equal, False otherwise.
The time taken is independent of the number of characters that match.
"""
if len(val1) != len(val2):
return False
result = 0
for x, y in zip(val1, val2):
result |= ord(x) ^ ord(y)
return result == 0
# Use the system (hardware-based) random number generator if it exists.
if hasattr(random, 'SystemRandom'):
randrange = random.SystemRandom().randrange
else:
randrange = random.randrange
_MAX_CSRF_KEY = 18446744073709551616L # 2 << 63
REASON_NO_REFERER = "Referer checking failed - no Referer."
REASON_BAD_REFERER = "Referer checking failed - %s does not match %s."
REASON_NO_CSRF_COOKIE = "CSRF cookie not set."
REASON_BAD_TOKEN = "CSRF token missing or incorrect."
def _get_failure_view():
"""
Returns the view to be used for CSRF rejections
"""
return get_callable(settings.CSRF_FAILURE_VIEW)
def _get_new_csrf_key():
return hashlib.md5("%s%s" % (randrange(0, _MAX_CSRF_KEY), settings.SECRET_KEY)).hexdigest()
def get_token(request):
"""
Returns the the CSRF token required for a POST form. The token is an
alphanumeric value.
A side effect of calling this function is to make the the csrf_protect
decorator and the CsrfViewMiddleware add a CSRF cookie and a 'Vary: Cookie'
header to the outgoing response. For this reason, you may need to use this
function lazily, as is done by the csrf context processor.
"""
request.META["CSRF_COOKIE_USED"] = True
return request.META.get("CSRF_COOKIE", None)
def _sanitize_token(token):
# Allow only alphanum, and ensure we return a 'str' for the sake of the post
# processing middleware.
token = re.sub('[^a-zA-Z0-9]', '', str(token.decode('ascii', 'ignore')))
if token == "":
# In case the cookie has been truncated to nothing at some point.
return _get_new_csrf_key()
else:
return token
class CsrfViewMiddleware(object):
"""
Middleware that requires a present and correct csrfmiddlewaretoken
for POST requests that have a CSRF cookie, and sets an outgoing
CSRF cookie.
This middleware should be used in conjunction with the csrf_token template
tag.
"""
# The _accept and _reject methods currently only exist for the sake of the
# requires_csrf_token decorator.
def _accept(self, request):
# Avoid checking the request twice by adding a custom attribute to
# request. This will be relevant when both decorator and middleware
# are used.
request.csrf_processing_done = True
return None
def _reject(self, request, reason):
return _get_failure_view()(request, reason=reason)
def process_view(self, request, callback, callback_args, callback_kwargs):
if getattr(request, 'csrf_processing_done', False):
return None
try:
csrf_token = _sanitize_token(request.COOKIES[settings.CSRF_COOKIE_NAME])
# Use same token next time
request.META['CSRF_COOKIE'] = csrf_token
except KeyError:
csrf_token = None
# Generate token and store it in the request, so it's available to the view.
request.META["CSRF_COOKIE"] = _get_new_csrf_key()
# Wait until request.META["CSRF_COOKIE"] has been manipulated before
# bailing out, so that get_token still works
if getattr(callback, 'csrf_exempt', False):
return None
# Assume that anything not defined as 'safe' by RC2616 needs protection.
if request.method not in ('GET', 'HEAD', 'OPTIONS', 'TRACE'):
if getattr(request, '_dont_enforce_csrf_checks', False):
# Mechanism to turn off CSRF checks for test suite. It comes after
# the creation of CSRF cookies, so that everything else continues to
# work exactly the same (e.g. cookies are sent etc), but before the
# any branches that call reject()
return self._accept(request)
if request.is_secure():
# Suppose user visits http://example.com/
# An active network attacker,(man-in-the-middle, MITM) sends a
# POST form which targets https://example.com/detonate-bomb/ and
# submits it via javascript.
#
# The attacker will need to provide a CSRF cookie and token, but
# that is no problem for a MITM and the session independent
# nonce we are using. So the MITM can circumvent the CSRF
# protection. This is true for any HTTP connection, but anyone
# using HTTPS expects better! For this reason, for
# https://example.com/ we need additional protection that treats
# http://example.com/ as completely untrusted. Under HTTPS,
# Barth et al. found that the Referer header is missing for
# same-domain requests in only about 0.2% of cases or less, so
# we can use strict Referer checking.
referer = request.META.get('HTTP_REFERER')
if referer is None:
logger.warning('Forbidden (%s): %s' % (REASON_NO_REFERER, request.path),
extra={
'status_code': 403,
'request': request,
}
)
return self._reject(request, REASON_NO_REFERER)
# Note that request.get_host() includes the port
good_referer = 'https://%s/' % request.get_host()
if not same_origin(referer, good_referer):
reason = REASON_BAD_REFERER % (referer, good_referer)
logger.warning('Forbidden (%s): %s' % (reason, request.path),
extra={
'status_code': 403,
'request': request,
}
)
return self._reject(request, reason)
if csrf_token is None:
# No CSRF cookie. For POST requests, we insist on a CSRF cookie,
# and in this way we can avoid all CSRF attacks, including login
# CSRF.
logger.warning('Forbidden (%s): %s' % (REASON_NO_CSRF_COOKIE, request.path),
extra={
'status_code': 403,
'request': request,
}
)
return self._reject(request, REASON_NO_CSRF_COOKIE)
# check non-cookie token for match
request_csrf_token = ""
if request.method == "POST":
request_csrf_token = request.POST.get('csrfmiddlewaretoken', '')
if request_csrf_token == "":
# Fall back to X-CSRFToken, to make things easier for AJAX,
# and possible for PUT/DELETE
request_csrf_token = request.META.get('HTTP_X_CSRFTOKEN', '')
if not constant_time_compare(request_csrf_token, csrf_token):
logger.warning('Forbidden (%s): %s' % (REASON_BAD_TOKEN, request.path),
extra={
'status_code': 403,
'request': request,
}
)
return self._reject(request, REASON_BAD_TOKEN)
return self._accept(request)
# Markdown is optional
try:
import markdown
class CustomSetextHeaderProcessor(markdown.blockprocessors.BlockProcessor):
"""
Class for markdown < 2.1
Override `markdown`'s :class:`SetextHeaderProcessor`, so that ==== headers are <h2> and ---- heade
We use <h1> for the resource name.
"""
import re
# Detect Setext-style header. Must be first 2 lines of block.
RE = re.compile(r'^.*?\n[=-]{3,}', re.MULTILINE)
def test(self, parent, block):
return bool(self.RE.match(block))
def run(self, parent, blocks):
lines = blocks.pop(0).split('\n')
# Determine level. ``=`` is 1 and ``-`` is 2.
if lines[1].startswith('='):
level = 2
else:
level = 3
h = markdown.etree.SubElement(parent, 'h%d' % level)
h.text = lines[0].strip()
if len(lines) > 2:
# Block contains additional lines. Add to master blocks for later.
blocks.insert(0, '\n'.join(lines[2:]))
def apply_markdown(text):
"""
Simple wrapper around :func:`markdown.markdown` to set the base level
of '#' style headers to <h2>.
"""
extensions = ['headerid(level=2)']
safe_mode = False,
if markdown.version_info < (2, 1):
output_format = markdown.DEFAULT_OUTPUT_FORMAT
md = markdown.Markdown(extensions=markdown.load_extensions(extensions),
safe_mode=safe_mode,
output_format=output_format)
md.parser.blockprocessors['setextheader'] = CustomSetextHeaderProcessor(md.parser)
else:
md = markdown.Markdown(extensions=extensions, safe_mode=safe_mode)
return md.convert(text)
except ImportError:
apply_markdown = None
# Yaml is optional
try:
import yaml
except ImportError:
yaml = None
import unittest
try:
import unittest.skip
except ImportError: # python < 2.7
from unittest import TestCase
import functools
def skip(reason):
# Pasted from py27/lib/unittest/case.py
"""
Unconditionally skip a test.
"""
def decorator(test_item):
if not (isinstance(test_item, type) and issubclass(test_item, TestCase)):
@functools.wraps(test_item)
def skip_wrapper(*args, **kwargs):
pass
test_item = skip_wrapper
test_item.__unittest_skip__ = True
test_item.__unittest_skip_why__ = reason
return test_item
return decorator
unittest.skip = skip

View File

@@ -0,0 +1,760 @@
"""
The :mod:`mixins` module provides a set of reusable `mixin`
classes that can be added to a `View`.
"""
from django.contrib.auth.models import AnonymousUser
from django.core.paginator import Paginator
from django.db.models.fields.related import ForeignKey
from django.http import HttpResponse
from urlobject import URLObject
from djangorestframework import status
from djangorestframework.renderers import BaseRenderer
from djangorestframework.resources import Resource, FormResource, ModelResource
from djangorestframework.response import Response, ErrorResponse
from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX
from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence
from StringIO import StringIO
__all__ = (
# Base behavior mixins
'RequestMixin',
'ResponseMixin',
'AuthMixin',
'ResourceMixin',
# Reverse URL lookup behavior
'InstanceMixin',
# Model behavior mixins
'ReadModelMixin',
'CreateModelMixin',
'UpdateModelMixin',
'DeleteModelMixin',
'ListModelMixin'
)
########## Request Mixin ##########
class RequestMixin(object):
"""
`Mixin` class to provide request parsing behavior.
"""
_USE_FORM_OVERLOADING = True
_METHOD_PARAM = '_method'
_CONTENTTYPE_PARAM = '_content_type'
_CONTENT_PARAM = '_content'
parsers = ()
"""
The set of request parsers that the view can handle.
Should be a tuple/list of classes as described in the :mod:`parsers` module.
"""
@property
def method(self):
"""
Returns the HTTP method.
This should be used instead of just reading :const:`request.method`, as it allows the `method`
to be overridden by using a hidden `form` field on a form POST request.
"""
if not hasattr(self, '_method'):
self._load_method_and_content_type()
return self._method
@property
def content_type(self):
"""
Returns the content type header.
This should be used instead of ``request.META.get('HTTP_CONTENT_TYPE')``,
as it allows the content type to be overridden by using a hidden form
field on a form POST request.
"""
if not hasattr(self, '_content_type'):
self._load_method_and_content_type()
return self._content_type
@property
def DATA(self):
"""
Parses the request body and returns the data.
Similar to ``request.POST``, except that it handles arbitrary parsers,
and also works on methods other than POST (eg PUT).
"""
if not hasattr(self, '_data'):
self._load_data_and_files()
return self._data
@property
def FILES(self):
"""
Parses the request body and returns the files.
Similar to ``request.FILES``, except that it handles arbitrary parsers,
and also works on methods other than POST (eg PUT).
"""
if not hasattr(self, '_files'):
self._load_data_and_files()
return self._files
def _load_data_and_files(self):
"""
Parse the request content into self.DATA and self.FILES.
"""
if not hasattr(self, '_content_type'):
self._load_method_and_content_type()
if not hasattr(self, '_data'):
(self._data, self._files) = self._parse(self._get_stream(), self._content_type)
def _load_method_and_content_type(self):
"""
Set the method and content_type, and then check if they've been overridden.
"""
self._method = self.request.method
self._content_type = self.request.META.get('HTTP_CONTENT_TYPE', self.request.META.get('CONTENT_TYPE', ''))
self._perform_form_overloading()
def _get_stream(self):
"""
Returns an object that may be used to stream the request content.
"""
request = self.request
try:
content_length = int(request.META.get('CONTENT_LENGTH', request.META.get('HTTP_CONTENT_LENGTH')))
except (ValueError, TypeError):
content_length = 0
# TODO: Add 1.3's LimitedStream to compat and use that.
# NOTE: Currently only supports parsing request body as a stream with 1.3
if content_length == 0:
return None
elif hasattr(request, 'read'):
return request
return StringIO(request.raw_post_data)
def _perform_form_overloading(self):
"""
If this is a form POST request, then we need to check if the method and content/content_type have been
overridden by setting them in hidden form fields or not.
"""
# We only need to use form overloading on form POST requests.
if not self._USE_FORM_OVERLOADING or self._method != 'POST' or not is_form_media_type(self._content_type):
return
# At this point we're committed to parsing the request as form data.
self._data = data = self.request.POST.copy()
self._files = self.request.FILES
# Method overloading - change the method and remove the param from the content.
if self._METHOD_PARAM in data:
# NOTE: unlike `get`, `pop` on a `QueryDict` seems to return a list of values.
self._method = self._data.pop(self._METHOD_PARAM)[0].upper()
# Content overloading - modify the content type, and re-parse.
if self._CONTENT_PARAM in data and self._CONTENTTYPE_PARAM in data:
self._content_type = self._data.pop(self._CONTENTTYPE_PARAM)[0]
stream = StringIO(self._data.pop(self._CONTENT_PARAM)[0])
(self._data, self._files) = self._parse(stream, self._content_type)
def _parse(self, stream, content_type):
"""
Parse the request content.
May raise a 415 ErrorResponse (Unsupported Media Type), or a 400 ErrorResponse (Bad Request).
"""
if stream is None or content_type is None:
return (None, None)
parsers = as_tuple(self.parsers)
for parser_cls in parsers:
parser = parser_cls(self)
if parser.can_handle_request(content_type):
return parser.parse(stream)
raise ErrorResponse(status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
{'error': 'Unsupported media type in request \'%s\'.' %
content_type})
@property
def _parsed_media_types(self):
"""
Return a list of all the media types that this view can parse.
"""
return [parser.media_type for parser in self.parsers]
@property
def _default_parser(self):
"""
Return the view's default parser class.
"""
return self.parsers[0]
########## ResponseMixin ##########
class ResponseMixin(object):
"""
Adds behavior for pluggable `Renderers` to a :class:`views.View` class.
Default behavior is to use standard HTTP Accept header content negotiation.
Also supports overriding the content type by specifying an ``_accept=`` parameter in the URL.
Ignores Accept headers from Internet Explorer user agents and uses a sensible browser Accept header instead.
"""
_ACCEPT_QUERY_PARAM = '_accept' # Allow override of Accept header in URL query params
_IGNORE_IE_ACCEPT_HEADER = True
renderers = ()
"""
The set of response renderers that the view can handle.
Should be a tuple/list of classes as described in the :mod:`renderers` module.
"""
def get_renderers(self):
"""
Return an iterable of available renderers. Override if you want to change
this list at runtime, say depending on what settings you have enabled.
"""
return self.renderers
# TODO: wrap this behavior around dispatch(), ensuring it works
# out of the box with existing Django classes that use render_to_response.
def render(self, response):
"""
Takes a :obj:`Response` object and returns an :obj:`HttpResponse`.
"""
self.response = response
try:
renderer, media_type = self._determine_renderer(self.request)
except ErrorResponse, exc:
renderer = self._default_renderer(self)
media_type = renderer.media_type
response = exc.response
# Set the media type of the response
# Note that the renderer *could* override it in .render() if required.
response.media_type = renderer.media_type
# Serialize the response content
if response.has_content_body:
content = renderer.render(response.cleaned_content, media_type)
else:
content = renderer.render()
# Build the HTTP Response
resp = HttpResponse(content, mimetype=response.media_type, status=response.status)
for (key, val) in response.headers.items():
resp[key] = val
return resp
def _determine_renderer(self, request):
"""
Determines the appropriate renderer for the output, given the client's 'Accept' header,
and the :attr:`renderers` set on this class.
Returns a 2-tuple of `(renderer, media_type)`
See: RFC 2616, Section 14 - http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
"""
if self._ACCEPT_QUERY_PARAM and request.GET.get(self._ACCEPT_QUERY_PARAM, None):
# Use _accept parameter override
accept_list = [request.GET.get(self._ACCEPT_QUERY_PARAM)]
elif (self._IGNORE_IE_ACCEPT_HEADER and
'HTTP_USER_AGENT' in request.META and
MSIE_USER_AGENT_REGEX.match(request.META['HTTP_USER_AGENT'])):
# Ignore MSIE's broken accept behavior and do something sensible instead
accept_list = ['text/html', '*/*']
elif 'HTTP_ACCEPT' in request.META:
# Use standard HTTP Accept negotiation
accept_list = [token.strip() for token in request.META['HTTP_ACCEPT'].split(',')]
else:
# No accept header specified
accept_list = ['*/*']
# Check the acceptable media types against each renderer,
# attempting more specific media types first
# NB. The inner loop here isn't as bad as it first looks :)
# Worst case is we're looping over len(accept_list) * len(self.renderers)
renderers = [renderer_cls(self) for renderer_cls in self.get_renderers()]
for accepted_media_type_lst in order_by_precedence(accept_list):
for renderer in renderers:
for accepted_media_type in accepted_media_type_lst:
if renderer.can_handle_response(accepted_media_type):
return renderer, accepted_media_type
# No acceptable renderers were found
raise ErrorResponse(status.HTTP_406_NOT_ACCEPTABLE,
{'detail': 'Could not satisfy the client\'s Accept header',
'available_types': self._rendered_media_types})
@property
def _rendered_media_types(self):
"""
Return an list of all the media types that this view can render.
"""
return [renderer.media_type for renderer in self.renderers]
@property
def _rendered_formats(self):
"""
Return a list of all the formats that this view can render.
"""
return [renderer.format for renderer in self.renderers]
@property
def _default_renderer(self):
"""
Return the view's default renderer class.
"""
return self.renderers[0]
########## Auth Mixin ##########
class AuthMixin(object):
"""
Simple :class:`mixin` class to add authentication and permission checking to a :class:`View` class.
"""
authentication = ()
"""
The set of authentication types that this view can handle.
Should be a tuple/list of classes as described in the :mod:`authentication` module.
"""
permissions = ()
"""
The set of permissions that will be enforced on this view.
Should be a tuple/list of classes as described in the :mod:`permissions` module.
"""
@property
def user(self):
"""
Returns the :obj:`user` for the current request, as determined by the set of
:class:`authentication` classes applied to the :class:`View`.
"""
if not hasattr(self, '_user'):
self._user = self._authenticate()
return self._user
def _authenticate(self):
"""
Attempt to authenticate the request using each authentication class in turn.
Returns a ``User`` object, which may be ``AnonymousUser``.
"""
for authentication_cls in self.authentication:
authentication = authentication_cls(self)
user = authentication.authenticate(self.request)
if user:
return user
return AnonymousUser()
# TODO: wrap this behavior around dispatch()
def _check_permissions(self):
"""
Check user permissions and either raise an ``ErrorResponse`` or return.
"""
user = self.user
for permission_cls in self.permissions:
permission = permission_cls(self)
permission.check_permission(user)
########## Resource Mixin ##########
class ResourceMixin(object):
"""
Provides request validation and response filtering behavior.
Should be a class as described in the :mod:`resources` module.
The :obj:`resource` is an object that maps a view onto it's representation on the server.
It provides validation on the content of incoming requests,
and filters the object representation into a serializable object for the response.
"""
resource = None
@property
def CONTENT(self):
"""
Returns the cleaned, validated request content.
May raise an :class:`response.ErrorResponse` with status code 400 (Bad Request).
"""
if not hasattr(self, '_content'):
self._content = self.validate_request(self.DATA, self.FILES)
return self._content
@property
def PARAMS(self):
"""
Returns the cleaned, validated query parameters.
May raise an :class:`response.ErrorResponse` with status code 400 (Bad Request).
"""
return self.validate_request(self.request.GET)
@property
def _resource(self):
if self.resource:
return self.resource(self)
elif getattr(self, 'model', None):
return ModelResource(self)
elif getattr(self, 'form', None):
return FormResource(self)
elif getattr(self, '%s_form' % self.method.lower(), None):
return FormResource(self)
return Resource(self)
def validate_request(self, data, files=None):
"""
Given the request *data* and optional *files*, return the cleaned, validated content.
May raise an :class:`response.ErrorResponse` with status code 400 (Bad Request) on failure.
"""
return self._resource.validate_request(data, files)
def filter_response(self, obj):
"""
Given the response content, filter it into a serializable object.
"""
return self._resource.filter_response(obj)
def get_bound_form(self, content=None, method=None):
if hasattr(self._resource, 'get_bound_form'):
return self._resource.get_bound_form(content, method=method)
else:
return None
##########
class InstanceMixin(object):
"""
`Mixin` class that is used to identify a `View` class as being the canonical identifier
for the resources it is mapped to.
"""
@classmethod
def as_view(cls, **initkwargs):
"""
Store the callable object on the resource class that has been associated with this view.
"""
view = super(InstanceMixin, cls).as_view(**initkwargs)
resource = getattr(cls(**initkwargs), 'resource', None)
if resource:
# We do a little dance when we store the view callable...
# we need to store it wrapped in a 1-tuple, so that inspect will treat it
# as a function when we later look it up (rather than turning it into a method).
# This makes sure our URL reversing works ok.
resource.view_callable = (view,)
return view
########## Model Mixins ##########
class ModelMixin(object):
""" Implements mechanisms used by other classes (like *ModelMixin group) to
define a query that represents Model instances the Mixin is working with.
If a *ModelMixin is going to retrive an instance (or queryset) using args and kwargs
passed by as URL arguments, it should provied arguments to objects.get and objects.filter
methods wrapped in by `build_query`
If a *ModelMixin is going to create/update an instance get_instance_data
handles the instance data creation/preaparation.
"""
queryset = None
def get_query_kwargs(self, *args, **kwargs):
"""
Return a dict of kwargs that will be used to build the
model instance retrieval or to filter querysets.
"""
kwargs = dict(kwargs)
# If the URLconf includes a .(?P<format>\w+) pattern to match against
# a .json, .xml suffix, then drop the 'format' kwarg before
# constructing the query.
if BaseRenderer._FORMAT_QUERY_PARAM in kwargs:
del kwargs[BaseRenderer._FORMAT_QUERY_PARAM]
return kwargs
def get_instance_data(self, model, content, **kwargs):
"""
Returns the dict with the data for model instance creation/update.
Arguments:
- model: model class (django.db.models.Model subclass) to work with
- content: a dictionary with instance data
- kwargs: a dict of URL provided keyword arguments
The create/update queries are created basicly with the contet provided
with POST/PUT HTML methods and kwargs passed in the URL. This methods
simply merges the URL data and the content preaparing the ready-to-use
data dictionary.
"""
tmp = dict(kwargs)
for field in model._meta.fields:
if isinstance(field, ForeignKey) and field.name in tmp:
# translate 'related_field' kwargs into 'related_field_id'
tmp[field.name + '_id'] = tmp[field.name]
del tmp[field.name]
all_kw_args = dict(content.items() + tmp.items())
return all_kw_args
def get_instance(self, **kwargs):
"""
Get a model instance for read/update/delete requests.
"""
return self.get_queryset().get(**kwargs)
def get_queryset(self):
"""
Return the queryset for this view.
"""
return getattr(self.resource, 'queryset',
self.resource.model.objects.all())
def get_ordering(self):
"""
Return the ordering for this view.
"""
return getattr(self.resource, 'ordering', None)
class ReadModelMixin(ModelMixin):
"""
Behavior to read a `model` instance on GET requests
"""
def get(self, request, *args, **kwargs):
model = self.resource.model
query_kwargs = self.get_query_kwargs(request, *args, **kwargs)
try:
self.model_instance = self.get_instance(**query_kwargs)
except model.DoesNotExist:
raise ErrorResponse(status.HTTP_404_NOT_FOUND)
return self.model_instance
class CreateModelMixin(ModelMixin):
"""
Behavior to create a `model` instance on POST requests
"""
def post(self, request, *args, **kwargs):
model = self.resource.model
# Copy the dict to keep self.CONTENT intact
content = dict(self.CONTENT)
m2m_data = {}
for field in model._meta.many_to_many:
if field.name in content:
m2m_data[field.name] = (
field.m2m_reverse_field_name(), content[field.name]
)
del content[field.name]
instance = model(**self.get_instance_data(model, content, *args, **kwargs))
instance.save()
for fieldname in m2m_data:
manager = getattr(instance, fieldname)
if hasattr(manager, 'add'):
manager.add(*m2m_data[fieldname][1])
else:
data = {}
data[manager.source_field_name] = instance
for related_item in m2m_data[fieldname][1]:
data[m2m_data[fieldname][0]] = related_item
manager.through(**data).save()
headers = {}
if hasattr(instance, 'get_absolute_url'):
headers['Location'] = self.resource(self).url(instance)
return Response(status.HTTP_201_CREATED, instance, headers)
class UpdateModelMixin(ModelMixin):
"""
Behavior to update a `model` instance on PUT requests
"""
def put(self, request, *args, **kwargs):
model = self.resource.model
query_kwargs = self.get_query_kwargs(request, *args, **kwargs)
# TODO: update on the url of a non-existing resource url doesn't work
# correctly at the moment - will end up with a new url
try:
self.model_instance = self.get_instance(**query_kwargs)
for (key, val) in self.CONTENT.items():
setattr(self.model_instance, key, val)
except model.DoesNotExist:
self.model_instance = model(**self.get_instance_data(model, self.CONTENT, *args, **kwargs))
self.model_instance.save()
return self.model_instance
class DeleteModelMixin(ModelMixin):
"""
Behavior to delete a `model` instance on DELETE requests
"""
def delete(self, request, *args, **kwargs):
model = self.resource.model
query_kwargs = self.get_query_kwargs(request, *args, **kwargs)
try:
instance = self.get_instance(**query_kwargs)
except model.DoesNotExist:
raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {})
instance.delete()
return
class ListModelMixin(ModelMixin):
"""
Behavior to list a set of `model` instances on GET requests
"""
def get(self, request, *args, **kwargs):
queryset = self.get_queryset()
ordering = self.get_ordering()
query_kwargs = self.get_query_kwargs(request, *args, **kwargs)
queryset = queryset.filter(**query_kwargs)
if ordering:
queryset = queryset.order_by(*ordering)
return queryset
########## Pagination Mixins ##########
class PaginatorMixin(object):
"""
Adds pagination support to GET requests
Obviously should only be used on lists :)
A default limit can be set by setting `limit` on the object. This will also
be used as the maximum if the client sets the `limit` GET param
"""
limit = 20
def get_limit(self):
"""
Helper method to determine what the `limit` should be
"""
try:
limit = int(self.request.GET.get('limit', self.limit))
return min(limit, self.limit)
except ValueError:
return self.limit
def url_with_page_number(self, page_number):
"""
Constructs a url used for getting the next/previous urls
"""
url = URLObject(self.request.get_full_path())
url = url.set_query_param('page', str(page_number))
limit = self.get_limit()
if limit != self.limit:
url = url.set_query_param('limit', str(limit))
return url
def next(self, page):
"""
Returns a url to the next page of results (if any)
"""
if not page.has_next():
return None
return self.url_with_page_number(page.next_page_number())
def previous(self, page):
""" Returns a url to the previous page of results (if any) """
if not page.has_previous():
return None
return self.url_with_page_number(page.previous_page_number())
def serialize_page_info(self, page):
"""
This is some useful information that is added to the response
"""
return {
'next': self.next(page),
'page': page.number,
'pages': page.paginator.num_pages,
'per_page': self.get_limit(),
'previous': self.previous(page),
'total': page.paginator.count,
}
def filter_response(self, obj):
"""
Given the response content, paginate and then serialize.
The response is modified to include to useful data relating to the number
of objects, number of pages, next/previous urls etc. etc.
The serialised objects are put into `results` on this new, modified
response
"""
# We don't want to paginate responses for anything other than GET requests
if self.method.upper() != 'GET':
return self._resource.filter_response(obj)
paginator = Paginator(obj, self.get_limit())
try:
page_num = int(self.request.GET.get('page', '1'))
except ValueError:
raise ErrorResponse(status.HTTP_404_NOT_FOUND,
{'detail': 'That page contains no results'})
if page_num not in paginator.page_range:
raise ErrorResponse(status.HTTP_404_NOT_FOUND,
{'detail': 'That page contains no results'})
page = paginator.page(page_num)
serialized_object_list = self._resource.filter_response(page.object_list)
serialized_page_info = self.serialize_page_info(page)
serialized_page_info['results'] = serialized_object_list
return serialized_page_info

View File

@@ -0,0 +1 @@
# Just to keep things like ./manage.py test happy

View File

@@ -0,0 +1,252 @@
"""
Django supports parsing the content of an HTTP request, but only for form POST requests.
That behavior is sufficient for dealing with standard HTML forms, but it doesn't map well
to general HTTP requests.
We need a method to be able to:
1.) Determine the parsed content on a request for methods other than POST (eg typically also PUT)
2.) Determine the parsed content on a request for media types other than application/x-www-form-urlencoded
and multipart/form-data. (eg also handle multipart/json)
"""
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError
from django.utils import simplejson as json
from djangorestframework import status
from djangorestframework.compat import yaml
from djangorestframework.response import ErrorResponse
from djangorestframework.utils.mediatypes import media_type_matches
from xml.etree import ElementTree as ET
import datetime
import decimal
__all__ = (
'BaseParser',
'JSONParser',
'PlainTextParser',
'FormParser',
'MultiPartParser',
'YAMLParser',
'XMLParser'
)
class BaseParser(object):
"""
All parsers should extend :class:`BaseParser`, specifying a :attr:`media_type` attribute,
and overriding the :meth:`parse` method.
"""
media_type = None
def __init__(self, view):
"""
Initialize the parser with the ``View`` instance as state,
in case the parser needs to access any metadata on the :obj:`View` object.
"""
self.view = view
def can_handle_request(self, content_type):
"""
Returns :const:`True` if this parser is able to deal with the given *content_type*.
The default implementation for this function is to check the *content_type*
argument against the :attr:`media_type` attribute set on the class to see if
they match.
This may be overridden to provide for other behavior, but typically you'll
instead want to just set the :attr:`media_type` attribute on the class.
"""
return media_type_matches(self.media_type, content_type)
def parse(self, stream):
"""
Given a *stream* to read from, return the deserialized output.
Should return a 2-tuple of (data, files).
"""
raise NotImplementedError("BaseParser.parse() Must be overridden to be implemented.")
class JSONParser(BaseParser):
"""
Parses JSON-serialized data.
"""
media_type = 'application/json'
def parse(self, stream):
"""
Returns a 2-tuple of `(data, files)`.
`data` will be an object which is the parsed content of the response.
`files` will always be `None`.
"""
try:
return (json.load(stream), None)
except ValueError, exc:
raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
{'detail': 'JSON parse error - %s' % unicode(exc)})
if yaml:
class YAMLParser(BaseParser):
"""
Parses YAML-serialized data.
"""
media_type = 'application/yaml'
def parse(self, stream):
"""
Returns a 2-tuple of `(data, files)`.
`data` will be an object which is the parsed content of the response.
`files` will always be `None`.
"""
try:
return (yaml.safe_load(stream), None)
except ValueError, exc:
raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
{'detail': 'YAML parse error - %s' % unicode(exc)})
else:
YAMLParser = None
class PlainTextParser(BaseParser):
"""
Plain text parser.
"""
media_type = 'text/plain'
def parse(self, stream):
"""
Returns a 2-tuple of `(data, files)`.
`data` will simply be a string representing the body of the request.
`files` will always be `None`.
"""
return (stream.read(), None)
class FormParser(BaseParser):
"""
Parser for form data.
"""
media_type = 'application/x-www-form-urlencoded'
def parse(self, stream):
"""
Returns a 2-tuple of `(data, files)`.
`data` will be a :class:`QueryDict` containing all the form parameters.
`files` will always be :const:`None`.
"""
data = QueryDict(stream.read())
return (data, None)
class MultiPartParser(BaseParser):
"""
Parser for multipart form data, which may include file data.
"""
media_type = 'multipart/form-data'
def parse(self, stream):
"""
Returns a 2-tuple of `(data, files)`.
`data` will be a :class:`QueryDict` containing all the form parameters.
`files` will be a :class:`QueryDict` containing all the form files.
"""
upload_handlers = self.view.request._get_upload_handlers()
try:
django_parser = DjangoMultiPartParser(self.view.request.META, stream, upload_handlers)
except MultiPartParserError, exc:
raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
{'detail': 'multipart parse error - %s' % unicode(exc)})
return django_parser.parse()
class XMLParser(BaseParser):
"""
XML parser.
"""
media_type = 'application/xml'
def parse(self, stream):
"""
Returns a 2-tuple of `(data, files)`.
`data` will simply be a string representing the body of the request.
`files` will always be `None`.
"""
tree = ET.parse(stream)
data = self._xml_convert(tree.getroot())
return (data, None)
def _xml_convert(self, element):
"""
convert the xml `element` into the corresponding python object
"""
children = element.getchildren()
if len(children) == 0:
return self._type_convert(element.text)
else:
# if the fist child tag is list-item means all children are list-item
if children[0].tag == "list-item":
data = []
for child in children:
data.append(self._xml_convert(child))
else:
data = {}
for child in children:
data[child.tag] = self._xml_convert(child)
return data
def _type_convert(self, value):
"""
Converts the value returned by the XMl parse into the equivalent
Python type
"""
if value is None:
return value
try:
return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
except ValueError:
pass
try:
return int(value)
except ValueError:
pass
try:
return decimal.Decimal(value)
except decimal.InvalidOperation:
pass
return value
DEFAULT_PARSERS = (
JSONParser,
FormParser,
MultiPartParser,
XMLParser
)
if YAMLParser:
DEFAULT_PARSERS += (YAMLParser,)

View File

@@ -0,0 +1,267 @@
"""
The :mod:`permissions` module bundles a set of permission classes that are used
for checking if a request passes a certain set of constraints. You can assign a permission
class to your view by setting your View's :attr:`permissions` class attribute.
"""
from django.core.cache import cache
from djangorestframework import status
from djangorestframework.response import ErrorResponse
import time
__all__ = (
'BasePermission',
'FullAnonAccess',
'IsAuthenticated',
'IsAdminUser',
'IsUserOrIsAnonReadOnly',
'PerUserThrottling',
'PerViewThrottling',
'PerResourceThrottling'
)
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
_403_FORBIDDEN_RESPONSE = ErrorResponse(
status.HTTP_403_FORBIDDEN,
{'detail': 'You do not have permission to access this resource. ' +
'You may need to login or otherwise authenticate the request.'})
_503_SERVICE_UNAVAILABLE = ErrorResponse(
status.HTTP_503_SERVICE_UNAVAILABLE,
{'detail': 'request was throttled'})
class BasePermission(object):
"""
A base class from which all permission classes should inherit.
"""
def __init__(self, view):
"""
Permission classes are always passed the current view on creation.
"""
self.view = view
def check_permission(self, auth):
"""
Should simply return, or raise an :exc:`response.ErrorResponse`.
"""
pass
class FullAnonAccess(BasePermission):
"""
Allows full access.
"""
def check_permission(self, user):
pass
class IsAuthenticated(BasePermission):
"""
Allows access only to authenticated users.
"""
def check_permission(self, user):
if not user.is_authenticated():
raise _403_FORBIDDEN_RESPONSE
class IsAdminUser(BasePermission):
"""
Allows access only to admin users.
"""
def check_permission(self, user):
if not user.is_staff:
raise _403_FORBIDDEN_RESPONSE
class IsUserOrIsAnonReadOnly(BasePermission):
"""
The request is authenticated as a user, or is a read-only request.
"""
def check_permission(self, user):
if (not user.is_authenticated() and
self.view.method not in SAFE_METHODS):
raise _403_FORBIDDEN_RESPONSE
class DjangoModelPermissions(BasePermission):
"""
The request is authenticated using `django.contrib.auth` permissions.
See: https://docs.djangoproject.com/en/dev/topics/auth/#permissions
It ensures that the user is authenticated, and has the appropriate
`add`/`change`/`delete` permissions on the model.
This permission should only be used on views with a `ModelResource`.
"""
# Map methods into required permission codes.
# Override this if you need to also provide 'read' permissions,
# or if you want to provide custom permission codes.
perms_map = {
'GET': [],
'OPTIONS': [],
'HEAD': [],
'POST': ['%(app_label)s.add_%(model_name)s'],
'PUT': ['%(app_label)s.change_%(model_name)s'],
'PATCH': ['%(app_label)s.change_%(model_name)s'],
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
}
def get_required_permissions(self, method, model_cls):
"""
Given a model and an HTTP method, return the list of permission
codes that the user is required to have.
"""
kwargs = {
'app_label': model_cls._meta.app_label,
'model_name': model_cls._meta.module_name
}
try:
return [perm % kwargs for perm in self.perms_map[method]]
except KeyError:
ErrorResponse(status.HTTP_405_METHOD_NOT_ALLOWED)
def check_permission(self, user):
method = self.view.method
model_cls = self.view.resource.model
perms = self.get_required_permissions(method, model_cls)
if not user.is_authenticated or not user.has_perms(perms):
raise _403_FORBIDDEN_RESPONSE
class BaseThrottle(BasePermission):
"""
Rate throttling of requests.
The rate (requests / seconds) is set by a :attr:`throttle` attribute
on the :class:`.View` class. The attribute is a string of the form 'number of
requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
Previous request information used for throttling is stored in the cache.
"""
attr_name = 'throttle'
default = '0/sec'
timer = time.time
def get_cache_key(self):
"""
Should return a unique cache-key which can be used for throttling.
Must be overridden.
"""
pass
def check_permission(self, auth):
"""
Check the throttling.
Return `None` or raise an :exc:`.ErrorResponse`.
"""
num, period = getattr(self.view, self.attr_name, self.default).split('/')
self.num_requests = int(num)
self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
self.auth = auth
self.check_throttle()
def check_throttle(self):
"""
Implement the check to see if the request should be throttled.
On success calls :meth:`throttle_success`.
On failure calls :meth:`throttle_failure`.
"""
self.key = self.get_cache_key()
self.history = cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
self.throttle_failure()
else:
self.throttle_success()
def throttle_success(self):
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration)
header = 'status=SUCCESS; next=%s sec' % self.next()
self.view.add_header('X-Throttle', header)
def throttle_failure(self):
"""
Called when a request to the API has failed due to throttling.
Raises a '503 service unavailable' response.
"""
header = 'status=FAILURE; next=%s sec' % self.next()
self.view.add_header('X-Throttle', header)
raise _503_SERVICE_UNAVAILABLE
def next(self):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
return '%.2f' % (remaining_duration / float(available_requests))
class PerUserThrottling(BaseThrottle):
"""
Limits the rate of API calls that may be made by a given user.
The user id will be used as a unique identifier if the user is
authenticated. For anonymous requests, the IP address of the client will
be used.
"""
def get_cache_key(self):
if self.auth.is_authenticated():
ident = self.auth.id
else:
ident = self.view.request.META.get('REMOTE_ADDR', None)
return 'throttle_user_%s' % ident
class PerViewThrottling(BaseThrottle):
"""
Limits the rate of API calls that may be used on a given view.
The class name of the view is used as a unique identifier to
throttle against.
"""
def get_cache_key(self):
return 'throttle_view_%s' % self.view.__class__.__name__
class PerResourceThrottling(BaseThrottle):
"""
Limits the rate of API calls that may be used against all views on
a given resource.
The class name of the resource is used as a unique identifier to
throttle against.
"""
def get_cache_key(self):
return 'throttle_resource_%s' % self.view.resource.__class__.__name__

View File

@@ -0,0 +1,413 @@
"""
Renderers are used to serialize a View's output into specific media types.
Django REST framework also provides HTML and PlainText renderers that help self-document the API,
by serializing the output along with documentation regarding the View, output status and headers,
and providing forms and links depending on the allowed methods, renderers and parsers on the View.
"""
from django import forms
from django.conf import settings
from django.core.serializers.json import DateTimeAwareJSONEncoder
from django.template import RequestContext, loader
from django.utils import simplejson as json
from djangorestframework.compat import yaml
from djangorestframework.utils import dict2xml, url_resolves
from djangorestframework.utils.breadcrumbs import get_breadcrumbs
from djangorestframework.utils.mediatypes import get_media_type_params, add_media_type_param, media_type_matches
from djangorestframework import VERSION
import string
from urllib import quote_plus
__all__ = (
'BaseRenderer',
'TemplateRenderer',
'JSONRenderer',
'JSONPRenderer',
'DocumentingHTMLRenderer',
'DocumentingXHTMLRenderer',
'DocumentingPlainTextRenderer',
'XMLRenderer',
'YAMLRenderer'
)
class BaseRenderer(object):
"""
All renderers must extend this class, set the :attr:`media_type` attribute,
and override the :meth:`render` method.
"""
_FORMAT_QUERY_PARAM = 'format'
media_type = None
format = None
def __init__(self, view):
self.view = view
def can_handle_response(self, accept):
"""
Returns :const:`True` if this renderer is able to deal with the given
*accept* media type.
The default implementation for this function is to check the *accept*
argument against the :attr:`media_type` attribute set on the class to see if
they match.
This may be overridden to provide for other behavior, but typically you'll
instead want to just set the :attr:`media_type` attribute on the class.
"""
format = self.view.kwargs.get(self._FORMAT_QUERY_PARAM, None)
if format is None:
format = self.view.request.GET.get(self._FORMAT_QUERY_PARAM, None)
if format is not None:
return format == self.format
return media_type_matches(self.media_type, accept)
def render(self, obj=None, media_type=None):
"""
Given an object render it into a string.
The requested media type is also passed to this method,
as it may contain parameters relevant to how the parser
should render the output.
EG: ``application/json; indent=4``
By default render simply returns the output as-is.
Override this method to provide for other behavior.
"""
if obj is None:
return ''
return str(obj)
class JSONRenderer(BaseRenderer):
"""
Renderer which serializes to JSON
"""
media_type = 'application/json'
format = 'json'
def render(self, obj=None, media_type=None):
"""
Renders *obj* into serialized JSON.
"""
if obj is None:
return ''
# If the media type looks like 'application/json; indent=4', then
# pretty print the result.
indent = get_media_type_params(media_type).get('indent', None)
sort_keys = False
try:
indent = max(min(int(indent), 8), 0)
sort_keys = True
except (ValueError, TypeError):
indent = None
return json.dumps(obj, cls=DateTimeAwareJSONEncoder, indent=indent, sort_keys=sort_keys)
class JSONPRenderer(JSONRenderer):
"""
Renderer which serializes to JSONP
"""
media_type = 'application/json-p'
format = 'json-p'
renderer_class = JSONRenderer
callback_parameter = 'callback'
def _get_callback(self):
return self.view.request.GET.get(self.callback_parameter, self.callback_parameter)
def _get_renderer(self):
return self.renderer_class(self.view)
def render(self, obj=None, media_type=None):
callback = self._get_callback()
json = self._get_renderer().render(obj, media_type)
return "%s(%s);" % (callback, json)
class XMLRenderer(BaseRenderer):
"""
Renderer which serializes to XML.
"""
media_type = 'application/xml'
format = 'xml'
def render(self, obj=None, media_type=None):
"""
Renders *obj* into serialized XML.
"""
if obj is None:
return ''
return dict2xml(obj)
if yaml:
class YAMLRenderer(BaseRenderer):
"""
Renderer which serializes to YAML.
"""
media_type = 'application/yaml'
format = 'yaml'
def render(self, obj=None, media_type=None):
"""
Renders *obj* into serialized YAML.
"""
if obj is None:
return ''
return yaml.safe_dump(obj)
else:
YAMLRenderer = None
class TemplateRenderer(BaseRenderer):
"""
A Base class provided for convenience.
Render the object simply by using the given template.
To create a template renderer, subclass this class, and set
the :attr:`media_type` and :attr:`template` attributes.
"""
media_type = None
template = None
def render(self, obj=None, media_type=None):
"""
Renders *obj* using the :attr:`template` specified on the class.
"""
if obj is None:
return ''
template = loader.get_template(self.template)
context = RequestContext(self.view.request, {'object': obj})
return template.render(context)
class DocumentingTemplateRenderer(BaseRenderer):
"""
Base class for renderers used to self-document the API.
Implementing classes should extend this class and set the template attribute.
"""
template = None
def _get_content(self, view, request, obj, media_type):
"""
Get the content as if it had been rendered by a non-documenting renderer.
(Typically this will be the content as it would have been if the Resource had been
requested with an 'Accept: */*' header, although with verbose style formatting if appropriate.)
"""
# Find the first valid renderer and render the content. (Don't use another documenting renderer.)
renderers = [renderer for renderer in view.renderers if not issubclass(renderer, DocumentingTemplateRenderer)]
if not renderers:
return '[No renderers were found]'
media_type = add_media_type_param(media_type, 'indent', '4')
content = renderers[0](view).render(obj, media_type)
if not all(char in string.printable for char in content):
return '[%d bytes of binary content]'
return content
def _get_form_instance(self, view, method):
"""
Get a form, possibly bound to either the input or output data.
In the absence on of the Resource having an associated form then
provide a form that can be used to submit arbitrary content.
"""
# Get the form instance if we have one bound to the input
form_instance = None
if method == getattr(view, 'method', view.request.method).lower():
form_instance = getattr(view, 'bound_form_instance', None)
if not form_instance and hasattr(view, 'get_bound_form'):
# Otherwise if we have a response that is valid against the form then use that
if view.response.has_content_body:
try:
form_instance = view.get_bound_form(view.response.cleaned_content, method=method)
if form_instance and not form_instance.is_valid():
form_instance = None
except Exception:
form_instance = None
# If we still don't have a form instance then try to get an unbound form
if not form_instance:
try:
form_instance = view.get_bound_form(method=method)
except Exception:
pass
# If we still don't have a form instance then try to get an unbound form which can tunnel arbitrary content types
if not form_instance:
form_instance = self._get_generic_content_form(view)
return form_instance
def _get_generic_content_form(self, view):
"""
Returns a form that allows for arbitrary content types to be tunneled via standard HTML forms
(Which are typically application/x-www-form-urlencoded)
"""
# If we're not using content overloading there's no point in supplying a generic form,
# as the view won't treat the form's value as the content of the request.
if not getattr(view, '_USE_FORM_OVERLOADING', False):
return None
# NB. http://jacobian.org/writing/dynamic-form-generation/
class GenericContentForm(forms.Form):
def __init__(self, view):
"""We don't know the names of the fields we want to set until the point the form is instantiated,
as they are determined by the Resource the form is being created against.
Add the fields dynamically."""
super(GenericContentForm, self).__init__()
contenttype_choices = [(media_type, media_type) for media_type in view._parsed_media_types]
initial_contenttype = view._default_parser.media_type
self.fields[view._CONTENTTYPE_PARAM] = forms.ChoiceField(label='Content Type',
choices=contenttype_choices,
initial=initial_contenttype)
self.fields[view._CONTENT_PARAM] = forms.CharField(label='Content',
widget=forms.Textarea)
# If either of these reserved parameters are turned off then content tunneling is not possible
if self.view._CONTENTTYPE_PARAM is None or self.view._CONTENT_PARAM is None:
return None
# Okey doke, let's do it
return GenericContentForm(view)
def get_name(self):
try:
return self.view.get_name()
except AttributeError:
return self.view.__doc__
def get_description(self, html=None):
if html is None:
html = bool('html' in self.format)
try:
return self.view.get_description(html)
except AttributeError:
return self.view.__doc__
def render(self, obj=None, media_type=None):
"""
Renders *obj* using the :attr:`template` set on the class.
The context used in the template contains all the information
needed to self-document the response to this request.
"""
content = self._get_content(self.view, self.view.request, obj, media_type)
put_form_instance = self._get_form_instance(self.view, 'put')
post_form_instance = self._get_form_instance(self.view, 'post')
if url_resolves(settings.LOGIN_URL) and url_resolves(settings.LOGOUT_URL):
login_url = "%s?next=%s" % (settings.LOGIN_URL, quote_plus(self.view.request.path))
logout_url = "%s?next=%s" % (settings.LOGOUT_URL, quote_plus(self.view.request.path))
else:
login_url = None
logout_url = None
name = self.get_name()
description = self.get_description()
breadcrumb_list = get_breadcrumbs(self.view.request.path)
template = loader.get_template(self.template)
context = RequestContext(self.view.request, {
'content': content,
'view': self.view,
'request': self.view.request, # TODO: remove
'response': self.view.response,
'description': description,
'name': name,
'version': VERSION,
'breadcrumblist': breadcrumb_list,
'available_formats': self.view._rendered_formats,
'put_form': put_form_instance,
'post_form': post_form_instance,
'login_url': login_url,
'logout_url': logout_url,
'FORMAT_PARAM': self._FORMAT_QUERY_PARAM,
'METHOD_PARAM': getattr(self.view, '_METHOD_PARAM', None),
'ADMIN_MEDIA_PREFIX': getattr(settings, 'ADMIN_MEDIA_PREFIX', None),
})
ret = template.render(context)
# Munge DELETE Response code to allow us to return content
# (Do this *after* we've rendered the template so that we include
# the normal deletion response code in the output)
if self.view.response.status == 204:
self.view.response.status = 200
return ret
class DocumentingHTMLRenderer(DocumentingTemplateRenderer):
"""
Renderer which provides a browsable HTML interface for an API.
See the examples at http://api.django-rest-framework.org to see this in action.
"""
media_type = 'text/html'
format = 'html'
template = 'djangorestframework/api.html'
class DocumentingXHTMLRenderer(DocumentingTemplateRenderer):
"""
Identical to DocumentingHTMLRenderer, except with an xhtml media type.
We need this to be listed in preference to xml in order to return HTML to WebKit based browsers,
given their Accept headers.
"""
media_type = 'application/xhtml+xml'
format = 'xhtml'
template = 'djangorestframework/api.html'
class DocumentingPlainTextRenderer(DocumentingTemplateRenderer):
"""
Renderer that serializes the object with the default renderer, but also provides plain-text
documentation of the returned status and headers, and of the resource's name and description.
Useful for browsing an API with command line tools.
"""
media_type = 'text/plain'
format = 'txt'
template = 'djangorestframework/api.txt'
DEFAULT_RENDERERS = (
JSONRenderer,
JSONPRenderer,
DocumentingHTMLRenderer,
DocumentingXHTMLRenderer,
DocumentingPlainTextRenderer,
XMLRenderer
)
if YAMLRenderer:
DEFAULT_RENDERERS += (YAMLRenderer,)

View File

@@ -0,0 +1,386 @@
from django import forms
from django.core.urlresolvers import reverse, get_urlconf, get_resolver, NoReverseMatch
from django.db import models
from djangorestframework.response import ErrorResponse
from djangorestframework.serializer import Serializer, _SkipField
from djangorestframework.utils import as_tuple
class BaseResource(Serializer):
"""
Base class for all Resource classes, which simply defines the interface they provide.
"""
fields = None
include = None
exclude = None
def __init__(self, view=None, depth=None, stack=[], **kwargs):
super(BaseResource, self).__init__(depth, stack, **kwargs)
self.view = view
def validate_request(self, data, files=None):
"""
Given the request content return the cleaned, validated content.
Typically raises a :exc:`response.ErrorResponse` with status code 400 (Bad Request) on failure.
"""
return data
def filter_response(self, obj):
"""
Given the response content, filter it into a serializable object.
"""
return self.serialize(obj)
class Resource(BaseResource):
"""
A Resource determines how a python object maps to some serializable data.
Objects that a resource can act on include plain Python object instances, Django Models, and Django QuerySets.
"""
# The model attribute refers to the Django Model which this Resource maps to.
# (The Model's class, rather than an instance of the Model)
model = None
# By default the set of returned fields will be the set of:
#
# 0. All the fields on the model, excluding 'id'.
# 1. All the properties on the model.
# 2. The absolute_url of the model, if a get_absolute_url method exists for the model.
#
# If you wish to override this behaviour,
# you should explicitly set the fields attribute on your class.
fields = None
class FormResource(Resource):
"""
Resource class that uses forms for validation.
Also provides a :meth:`get_bound_form` method which may be used by some renderers.
On calling :meth:`validate_request` this validator may set a :attr:`bound_form_instance` attribute on the
view, which may be used by some renderers.
"""
form = None
"""
The :class:`Form` class that should be used for request validation.
This can be overridden by a :attr:`form` attribute on the :class:`views.View`.
"""
allow_unknown_form_fields = False
"""
Flag to check for unknown fields when validating a form. If set to false and
we receive request data that is not expected by the form it raises an
:exc:`response.ErrorResponse` with status code 400. If set to true, only
expected fields are validated.
"""
def validate_request(self, data, files=None):
"""
Given some content as input return some cleaned, validated content.
Raises a :exc:`response.ErrorResponse` with status code 400 (Bad Request) on failure.
Validation is standard form validation, with an additional constraint that *no extra unknown fields* may be supplied
if :attr:`self.allow_unknown_form_fields` is ``False``.
On failure the :exc:`response.ErrorResponse` content is a dict which may contain :obj:`'errors'` and :obj:`'field-errors'` keys.
If the :obj:`'errors'` key exists it is a list of strings of non-field errors.
If the :obj:`'field-errors'` key exists it is a dict of ``{'field name as string': ['errors as strings', ...]}``.
"""
return self._validate(data, files)
def _validate(self, data, files, allowed_extra_fields=(), fake_data=None):
"""
Wrapped by validate to hide the extra flags that are used in the implementation.
allowed_extra_fields is a list of fields which are not defined by the form, but which we still
expect to see on the input.
fake_data is a string that should be used as an extra key, as a kludge to force .errors
to be populated when an empty dict is supplied in `data`
"""
# We'd like nice error messages even if no content is supplied.
# Typically if an empty dict is given to a form Django will
# return .is_valid() == False, but .errors == {}
#
# To get around this case we revalidate with some fake data.
if fake_data:
data[fake_data] = '_fake_data'
allowed_extra_fields = tuple(allowed_extra_fields) + ('_fake_data',)
bound_form = self.get_bound_form(data, files)
if bound_form is None:
return data
self.view.bound_form_instance = bound_form
data = data and data or {}
files = files and files or {}
seen_fields_set = set(data.keys())
form_fields_set = set(bound_form.fields.keys())
allowed_extra_fields_set = set(allowed_extra_fields)
# In addition to regular validation we also ensure no additional fields are being passed in...
unknown_fields = seen_fields_set - (form_fields_set | allowed_extra_fields_set)
unknown_fields = unknown_fields - set(('csrfmiddlewaretoken', '_accept', '_method')) # TODO: Ugh.
# Check using both regular validation, and our stricter no additional fields rule
if bound_form.is_valid() and (self.allow_unknown_form_fields or not unknown_fields):
# Validation succeeded...
cleaned_data = bound_form.cleaned_data
# Add in any extra fields to the cleaned content...
for key in (allowed_extra_fields_set & seen_fields_set) - set(cleaned_data.keys()):
cleaned_data[key] = data[key]
return cleaned_data
# Validation failed...
detail = {}
if not bound_form.errors and not unknown_fields:
# is_valid() was False, but errors was empty.
# If we havn't already done so attempt revalidation with some fake data
# to force django to give us an errors dict.
if fake_data is None:
return self._validate(data, files, allowed_extra_fields, '_fake_data')
# If we've already set fake_dict and we're still here, fallback gracefully.
detail = {u'errors': [u'No content was supplied.']}
else:
# Add any non-field errors
if bound_form.non_field_errors():
detail[u'errors'] = bound_form.non_field_errors()
# Add standard field errors
field_errors = dict(
(key, map(unicode, val))
for (key, val)
in bound_form.errors.iteritems()
if not key.startswith('__')
)
# Add any unknown field errors
for key in unknown_fields:
field_errors[key] = [u'This field does not exist.']
if field_errors:
detail[u'field_errors'] = field_errors
# Return HTTP 400 response (BAD REQUEST)
raise ErrorResponse(400, detail)
def get_form_class(self, method=None):
"""
Returns the form class used to validate this resource.
"""
# A form on the view overrides a form on the resource.
form = getattr(self.view, 'form', None) or self.form
# Use the requested method or determine the request method
if method is None and hasattr(self.view, 'request') and hasattr(self.view, 'method'):
method = self.view.method
elif method is None and hasattr(self.view, 'request'):
method = self.view.request.method
# A method form on the view or resource overrides the general case.
# Method forms are attributes like `get_form` `post_form` `put_form`.
if method:
form = getattr(self, '%s_form' % method.lower(), form)
form = getattr(self.view, '%s_form' % method.lower(), form)
return form
def get_bound_form(self, data=None, files=None, method=None):
"""
Given some content return a Django form bound to that content.
If form validation is turned off (:attr:`form` class attribute is :const:`None`) then returns :const:`None`.
"""
form = self.get_form_class(method)
if not form:
return None
if data is not None or files is not None:
return form(data, files)
return form()
class ModelResource(FormResource):
"""
Resource class that uses forms for validation and otherwise falls back to a model form if no form is set.
Also provides a :meth:`get_bound_form` method which may be used by some renderers.
"""
# Auto-register new ModelResource classes into _model_to_resource
#__metaclass__ = _RegisterModelResource
form = None
"""
The form class that should be used for request validation.
If set to :const:`None` then the default model form validation will be used.
This can be overridden by a :attr:`form` attribute on the :class:`views.View`.
"""
model = None
"""
The model class which this resource maps to.
This can be overridden by a :attr:`model` attribute on the :class:`views.View`.
"""
fields = None
"""
The list of fields to use on the output.
May be any of:
The name of a model field. To view nested resources, give the field as a tuple of ("fieldName", resource) where `resource` may be any of ModelResource reference, the name of a ModelResourc reference as a string or a tuple of strings representing fields on the nested model.
The name of an attribute on the model.
The name of an attribute on the resource.
The name of a method on the model, with a signature like ``func(self)``.
The name of a method on the resource, with a signature like ``func(self, instance)``.
"""
exclude = ('id', 'pk')
"""
The list of fields to exclude. This is only used if :attr:`fields` is not set.
"""
include = ('url',)
"""
The list of extra fields to include. This is only used if :attr:`fields` is not set.
"""
def __init__(self, view=None, depth=None, stack=[], **kwargs):
"""
Allow :attr:`form` and :attr:`model` attributes set on the
:class:`View` to override the :attr:`form` and :attr:`model`
attributes set on the :class:`Resource`.
"""
super(ModelResource, self).__init__(view, depth, stack, **kwargs)
self.model = getattr(view, 'model', None) or self.model
def validate_request(self, data, files=None):
"""
Given some content as input return some cleaned, validated content.
Raises a :exc:`response.ErrorResponse` with status code 400 (Bad Request) on failure.
Validation is standard form or model form validation,
with an additional constraint that no extra unknown fields may be supplied,
and that all fields specified by the fields class attribute must be supplied,
even if they are not validated by the form/model form.
On failure the ErrorResponse content is a dict which may contain :obj:`'errors'` and :obj:`'field-errors'` keys.
If the :obj:`'errors'` key exists it is a list of strings of non-field errors.
If the ''field-errors'` key exists it is a dict of {field name as string: list of errors as strings}.
"""
return self._validate(data, files, allowed_extra_fields=self._property_fields_set)
def get_bound_form(self, data=None, files=None, method=None):
"""
Given some content return a ``Form`` instance bound to that content.
If the :attr:`form` class attribute has been explicitly set then that class will be used
to create the Form, otherwise the model will be used to create a ModelForm.
"""
form = self.get_form_class(method)
if not form and self.model:
# Fall back to ModelForm which we create on the fly
class OnTheFlyModelForm(forms.ModelForm):
class Meta:
model = self.model
#fields = tuple(self._model_fields_set)
form = OnTheFlyModelForm
# Both form and model not set? Okay bruv, whatevs...
if not form:
return None
# Instantiate the ModelForm as appropriate
if data is not None or files is not None:
if issubclass(form, forms.ModelForm) and hasattr(self.view, 'model_instance'):
# Bound to an existing model instance
return form(data, files, instance=self.view.model_instance)
else:
return form(data, files)
return form()
def url(self, instance):
"""
Attempts to reverse resolve the url of the given model *instance* for this resource.
Requires a ``View`` with :class:`mixins.InstanceMixin` to have been created for this resource.
This method can be overridden if you need to set the resource url reversing explicitly.
"""
if not hasattr(self, 'view_callable'):
raise _SkipField
# dis does teh magicks...
urlconf = get_urlconf()
resolver = get_resolver(urlconf)
possibilities = resolver.reverse_dict.getlist(self.view_callable[0])
for tuple_item in possibilities:
possibility = tuple_item[0]
# pattern = tuple_item[1]
# Note: defaults = tuple_item[2] for django >= 1.3
for result, params in possibility:
#instance_attrs = dict([ (param, getattr(instance, param)) for param in params if hasattr(instance, param) ])
instance_attrs = {}
for param in params:
if not hasattr(instance, param):
continue
attr = getattr(instance, param)
if isinstance(attr, models.Model):
instance_attrs[param] = attr.pk
else:
instance_attrs[param] = attr
try:
return reverse(self.view_callable[0], kwargs=instance_attrs)
except NoReverseMatch:
pass
raise _SkipField
@property
def _model_fields_set(self):
"""
Return a set containing the names of validated fields on the model.
"""
model_fields = set(field.name for field in self.model._meta.fields)
if self.fields:
return model_fields & set(as_tuple(self.fields))
return model_fields - set(as_tuple(self.exclude))
@property
def _property_fields_set(self):
"""
Returns a set containing the names of validated properties on the model.
"""
property_fields = set(attr for attr in dir(self.model) if
isinstance(getattr(self.model, attr, None), property)
and not attr.startswith('_'))
if self.fields:
return property_fields & set(as_tuple(self.fields))
return property_fields.union(set(as_tuple(self.include))) - set(as_tuple(self.exclude))

View File

@@ -0,0 +1,44 @@
"""
The :mod:`response` module provides Response classes you can use in your
views to return a certain HTTP response. Typically a response is *rendered*
into a HTTP response depending on what renderers are set on your view and
als depending on the accept header of the request.
"""
from django.core.handlers.wsgi import STATUS_CODE_TEXT
__all__ = ('Response', 'ErrorResponse')
# TODO: remove raw_content/cleaned_content and just use content?
class Response(object):
"""
An HttpResponse that may include content that hasn't yet been serialized.
"""
def __init__(self, status=200, content=None, headers=None):
self.status = status
self.media_type = None
self.has_content_body = content is not None
self.raw_content = content # content prior to filtering
self.cleaned_content = content # content after filtering
self.headers = headers or {}
@property
def status_text(self):
"""
Return reason text corresponding to our HTTP response status code.
Provided for convenience.
"""
return STATUS_CODE_TEXT.get(self.status, '')
class ErrorResponse(Exception):
"""
An exception representing an Response that should be returned immediately.
Any content should be serialized as-is, without being filtered.
"""
def __init__(self, status, content=None, headers={}):
self.response = Response(status, content=content, headers=headers)

View File

@@ -0,0 +1,64 @@
"""
Useful tool to run the test suite for djangorestframework and generate a coverage report.
"""
# http://ericholscher.com/blog/2009/jun/29/enable-setuppy-test-your-django-apps/
# http://www.travisswicegood.com/2010/01/17/django-virtualenv-pip-and-fabric/
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
os.environ['DJANGO_SETTINGS_MODULE'] = 'djangorestframework.runtests.settings'
from coverage import coverage
from itertools import chain
def main():
"""Run the tests for djangorestframework and generate a coverage report."""
cov = coverage()
cov.erase()
cov.start()
from django.conf import settings
from django.test.utils import get_runner
TestRunner = get_runner(settings)
if hasattr(TestRunner, 'func_name'):
# Pre 1.2 test runners were just functions,
# and did not support the 'failfast' option.
import warnings
warnings.warn(
'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
DeprecationWarning
)
failures = TestRunner(['djangorestframework'])
else:
test_runner = TestRunner()
failures = test_runner.run_tests(['djangorestframework'])
cov.stop()
# Discover the list of all modules that we should test coverage for
import djangorestframework
project_dir = os.path.dirname(djangorestframework.__file__)
cov_files = []
for (path, dirs, files) in os.walk(project_dir):
# Drop tests and runtests directories from the test coverage report
if os.path.basename(path) == 'tests' or os.path.basename(path) == 'runtests':
continue
# Drop the compat module from coverage, since we're not interested in the coverage
# of a module which is specifically for resolving environment dependant imports.
# (Because we'll end up getting different coverage reports for it for each environment)
if 'compat.py' in files:
files.remove('compat.py')
cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')])
cov.report(cov_files)
cov.xml_report(cov_files)
sys.exit(failures)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,41 @@
'''
Created on Mar 10, 2011
@author: tomchristie
'''
# http://ericholscher.com/blog/2009/jun/29/enable-setuppy-test-your-django-apps/
# http://www.travisswicegood.com/2010/01/17/django-virtualenv-pip-and-fabric/
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
os.environ['DJANGO_SETTINGS_MODULE'] = 'djangorestframework.runtests.settings'
from django.conf import settings
from django.test.utils import get_runner
def usage():
return """
Usage: python runtests.py [UnitTestClass].[method]
You can pass the Class name of the `UnitTestClass` you want to test.
Append a method name if you only want to test a specific method of that class.
"""
def main():
TestRunner = get_runner(settings)
test_runner = TestRunner()
if len(sys.argv) == 2:
test_case = '.' + sys.argv[1]
elif len(sys.argv) == 1:
test_case = ''
else:
print usage()
sys.exit(1)
failures = test_runner.run_tests(['djangorestframework' + test_case])
sys.exit(failures)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,122 @@
# Django settings for testproject project.
DEBUG = True
TEMPLATE_DEBUG = DEBUG
DEBUG_PROPAGATE_EXCEPTIONS = True
ADMINS = (
# ('Your Name', 'your_email@domain.com'),
)
MANAGERS = ADMINS
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'.
'NAME': 'sqlite.db', # Or path to database file if using sqlite3.
'USER': '', # Not used with sqlite3.
'PASSWORD': '', # Not used with sqlite3.
'HOST': '', # Set to empty string for localhost. Not used with sqlite3.
'PORT': '', # Set to empty string for default. Not used with sqlite3.
}
}
# Local time zone for this installation. Choices can be found here:
# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name
# although not all choices may be available on all operating systems.
# On Unix systems, a value of None will cause Django to use the same
# timezone as the operating system.
# If running in a Windows environment this must be set to the same as your
# system time zone.
TIME_ZONE = 'Europe/London'
# Language code for this installation. All choices can be found here:
# http://www.i18nguy.com/unicode/language-identifiers.html
LANGUAGE_CODE = 'en-uk'
SITE_ID = 1
# If you set this to False, Django will make some optimizations so as not
# to load the internationalization machinery.
USE_I18N = True
# If you set this to False, Django will not format dates, numbers and
# calendars according to the current locale
USE_L10N = True
# Absolute filesystem path to the directory that will hold user-uploaded files.
# Example: "/home/media/media.lawrence.com/"
MEDIA_ROOT = ''
# URL that handles the media served from MEDIA_ROOT. Make sure to use a
# trailing slash if there is a path component (optional in other cases).
# Examples: "http://media.lawrence.com", "http://example.com/media/"
MEDIA_URL = ''
# URL prefix for admin media -- CSS, JavaScript and images. Make sure to use a
# trailing slash.
# Examples: "http://foo.com/media/", "/media/".
ADMIN_MEDIA_PREFIX = '/media/'
# Make this unique, and don't share it with anybody.
SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy'
# List of callables that know how to import templates from various sources.
TEMPLATE_LOADERS = (
'django.template.loaders.filesystem.Loader',
'django.template.loaders.app_directories.Loader',
# 'django.template.loaders.eggs.Loader',
)
MIDDLEWARE_CLASSES = (
'django.middleware.common.CommonMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
)
ROOT_URLCONF = 'urls'
TEMPLATE_DIRS = (
# Put strings here, like "/home/html/django_templates" or "C:/www/django/templates".
# Always use forward slashes, even on Windows.
# Don't forget to use absolute paths, not relative paths.
)
INSTALLED_APPS = (
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.sites',
'django.contrib.messages',
# Uncomment the next line to enable the admin:
# 'django.contrib.admin',
# Uncomment the next line to enable admin documentation:
# 'django.contrib.admindocs',
'djangorestframework',
)
STATIC_URL = '/static/'
import django
if django.VERSION < (1, 3):
INSTALLED_APPS += ('staticfiles',)
# OAuth support is optional, so we only test oauth if it's installed.
try:
import oauth_provider
except ImportError:
pass
else:
INSTALLED_APPS += ('oauth_provider',)
# If we're running on the Jenkins server we want to archive the coverage reports as XML.
import os
if os.environ.get('HUDSON_URL', None):
TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner'
TEST_OUTPUT_VERBOSE = True
TEST_OUTPUT_DESCRIPTIONS = True
TEST_OUTPUT_DIR = 'xmlrunner'

View File

@@ -0,0 +1,7 @@
"""
Blank URLConf just to keep runtests.py happy.
"""
from django.conf.urls.defaults import *
urlpatterns = patterns('',
)

View File

@@ -0,0 +1,298 @@
"""
Customizable serialization.
"""
from django.db import models
from django.db.models.query import QuerySet
from django.utils.encoding import smart_unicode, is_protected_type, smart_str
import inspect
import types
# We register serializer classes, so that we can refer to them by their
# class names, if there are cyclical serialization heirachys.
_serializers = {}
def _field_to_tuple(field):
"""
Convert an item in the `fields` attribute into a 2-tuple.
"""
if isinstance(field, (tuple, list)):
return (field[0], field[1])
return (field, None)
def _fields_to_list(fields):
"""
Return a list of field names.
"""
return [_field_to_tuple(field)[0] for field in fields or ()]
def _fields_to_dict(fields):
"""
Return a `dict` of field name -> None, or tuple of fields, or Serializer class
"""
return dict([_field_to_tuple(field) for field in fields or ()])
class _SkipField(Exception):
"""
Signals that a serialized field should be ignored.
We use this mechanism as the default behavior for ensuring
that we don't infinitely recurse when dealing with nested data.
"""
pass
class _RegisterSerializer(type):
"""
Metaclass to register serializers.
"""
def __new__(cls, name, bases, attrs):
# Build the class and register it.
ret = super(_RegisterSerializer, cls).__new__(cls, name, bases, attrs)
_serializers[name] = ret
return ret
class Serializer(object):
"""
Converts python objects into plain old native types suitable for
serialization. In particular it handles models and querysets.
The output format is specified by setting a number of attributes
on the class.
You may also override any of the serialization methods, to provide
for more flexible behavior.
Valid output types include anything that may be directly rendered into
json, xml etc...
"""
__metaclass__ = _RegisterSerializer
fields = ()
"""
Specify the fields to be serialized on a model or dict.
Overrides `include` and `exclude`.
"""
include = ()
"""
Fields to add to the default set to be serialized on a model/dict.
"""
exclude = ()
"""
Fields to remove from the default set to be serialized on a model/dict.
"""
rename = {}
"""
A dict of key->name to use for the field keys.
"""
related_serializer = None
"""
The default serializer class to use for any related models.
"""
depth = None
"""
The maximum depth to serialize to, or `None`.
"""
def __init__(self, depth=None, stack=[], **kwargs):
if depth is not None:
self.depth = depth
self.stack = stack
def get_fields(self, obj):
"""
Return the set of field names/keys to use for a model instance/dict.
"""
fields = self.fields
# If `fields` is not set, we use the default fields and modify
# them with `include` and `exclude`
if not fields:
default = self.get_default_fields(obj)
include = self.include or ()
exclude = self.exclude or ()
fields = set(default + list(include)) - set(exclude)
else:
fields = _fields_to_list(self.fields)
return fields
def get_default_fields(self, obj):
"""
Return the default list of field names/keys for a model instance/dict.
These are used if `fields` is not given.
"""
if isinstance(obj, models.Model):
opts = obj._meta
return [field.name for field in opts.fields + opts.many_to_many]
else:
return obj.keys()
def get_related_serializer(self, key):
info = _fields_to_dict(self.fields).get(key, None)
# If an element in `fields` is a 2-tuple of (str, tuple)
# then the second element of the tuple is the fields to
# set on the related serializer
if isinstance(info, (list, tuple)):
class OnTheFlySerializer(Serializer):
fields = info
return OnTheFlySerializer
# If an element in `fields` is a 2-tuple of (str, Serializer)
# then the second element of the tuple is the Serializer
# class to use for that field.
elif isinstance(info, type) and issubclass(info, Serializer):
return info
# If an element in `fields` is a 2-tuple of (str, str)
# then the second element of the tuple is the name of the Serializer
# class to use for that field.
#
# Black magic to deal with cyclical Serializer dependancies.
# Similar to what Django does for cyclically related models.
elif isinstance(info, str) and info in _serializers:
return _serializers[info]
# Otherwise use `related_serializer` or fall back to `Serializer`
return getattr(self, 'related_serializer') or Serializer
def serialize_key(self, key):
"""
Keys serialize to their string value,
unless they exist in the `rename` dict.
"""
return self.rename.get(smart_str(key), smart_str(key))
def serialize_val(self, key, obj):
"""
Convert a model field or dict value into a serializable representation.
"""
related_serializer = self.get_related_serializer(key)
if self.depth is None:
depth = None
elif self.depth <= 0:
return self.serialize_max_depth(obj)
else:
depth = self.depth - 1
if any([obj is elem for elem in self.stack]):
return self.serialize_recursion(obj)
else:
stack = self.stack[:]
stack.append(obj)
return related_serializer(depth=depth, stack=stack).serialize(obj)
def serialize_max_depth(self, obj):
"""
Determine how objects should be serialized once `depth` is exceeded.
The default behavior is to ignore the field.
"""
raise _SkipField
def serialize_recursion(self, obj):
"""
Determine how objects should be serialized if recursion occurs.
The default behavior is to ignore the field.
"""
raise _SkipField
def serialize_model(self, instance):
"""
Given a model instance or dict, serialize it to a dict..
"""
data = {}
fields = self.get_fields(instance)
# serialize each required field
for fname in fields:
try:
# we first check for a method 'fname' on self,
# 'fname's signature must be 'def fname(self, instance)'
meth = getattr(self, fname, None)
if (inspect.ismethod(meth) and
len(inspect.getargspec(meth)[0]) == 2):
obj = meth(instance)
elif hasattr(instance, '__contains__') and fname in instance:
# then check for a key 'fname' on the instance
obj = instance[fname]
elif hasattr(instance, smart_str(fname)):
# finally check for an attribute 'fname' on the instance
obj = getattr(instance, fname)
else:
continue
key = self.serialize_key(fname)
val = self.serialize_val(fname, obj)
data[key] = val
except _SkipField:
pass
return data
def serialize_iter(self, obj):
"""
Convert iterables into a serializable representation.
"""
return [self.serialize(item) for item in obj]
def serialize_func(self, obj):
"""
Convert no-arg methods and functions into a serializable representation.
"""
return self.serialize(obj())
def serialize_manager(self, obj):
"""
Convert a model manager into a serializable representation.
"""
return self.serialize_iter(obj.all())
def serialize_fallback(self, obj):
"""
Convert any unhandled object into a serializable representation.
"""
return smart_unicode(obj, strings_only=True)
def serialize(self, obj):
"""
Convert any object into a serializable representation.
"""
if isinstance(obj, (dict, models.Model)):
# Model instances & dictionaries
return self.serialize_model(obj)
elif isinstance(obj, (tuple, list, set, QuerySet, types.GeneratorType)):
# basic iterables
return self.serialize_iter(obj)
elif isinstance(obj, models.Manager):
# Manager objects
return self.serialize_manager(obj)
elif inspect.isfunction(obj) and not inspect.getargspec(obj)[0]:
# function with no args
return self.serialize_func(obj)
elif inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1:
# bound method
return self.serialize_func(obj)
# Protected types are passed through as is.
# (i.e. Primitives like None, numbers, dates, and Decimals.)
if is_protected_type(obj):
return obj
# All other values are converted to string.
return self.serialize_fallback(obj)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,48 @@
"""
Descriptive HTTP status codes, for code readability.
See RFC 2616 - Sec 10: http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html
Also see django.core.handlers.wsgi.STATUS_CODE_TEXT
"""
HTTP_100_CONTINUE = 100
HTTP_101_SWITCHING_PROTOCOLS = 101
HTTP_200_OK = 200
HTTP_201_CREATED = 201
HTTP_202_ACCEPTED = 202
HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203
HTTP_204_NO_CONTENT = 204
HTTP_205_RESET_CONTENT = 205
HTTP_206_PARTIAL_CONTENT = 206
HTTP_300_MULTIPLE_CHOICES = 300
HTTP_301_MOVED_PERMANENTLY = 301
HTTP_302_FOUND = 302
HTTP_303_SEE_OTHER = 303
HTTP_304_NOT_MODIFIED = 304
HTTP_305_USE_PROXY = 305
HTTP_306_RESERVED = 306
HTTP_307_TEMPORARY_REDIRECT = 307
HTTP_400_BAD_REQUEST = 400
HTTP_401_UNAUTHORIZED = 401
HTTP_402_PAYMENT_REQUIRED = 402
HTTP_403_FORBIDDEN = 403
HTTP_404_NOT_FOUND = 404
HTTP_405_METHOD_NOT_ALLOWED = 405
HTTP_406_NOT_ACCEPTABLE = 406
HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407
HTTP_408_REQUEST_TIMEOUT = 408
HTTP_409_CONFLICT = 409
HTTP_410_GONE = 410
HTTP_411_LENGTH_REQUIRED = 411
HTTP_412_PRECONDITION_FAILED = 412
HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413
HTTP_414_REQUEST_URI_TOO_LONG = 414
HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415
HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416
HTTP_417_EXPECTATION_FAILED = 417
HTTP_500_INTERNAL_SERVER_ERROR = 500
HTTP_501_NOT_IMPLEMENTED = 501
HTTP_502_BAD_GATEWAY = 502
HTTP_503_SERVICE_UNAVAILABLE = 503
HTTP_504_GATEWAY_TIMEOUT = 504
HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505

View File

@@ -0,0 +1,3 @@
{% extends "djangorestframework/base.html" %}
{# Override this template in your own templates directory to customize #}

View File

@@ -0,0 +1,8 @@
{% autoescape off %}{{ name }}
{{ description }}
HTTP/1.0 {{ response.status }} {{ response.status_text }}
{% for key, val in response.headers.items %}{{ key }}: {{ val }}
{% endfor %}
{{ content }}{% endautoescape %}

View File

@@ -0,0 +1,142 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
{% load urlize_quoted_links %}
{% load add_query_param %}
{% load static %}
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<link rel="stylesheet" type="text/css" href='{% get_static_prefix %}djangorestframework/css/style.css'/>
{% block extrastyle %}{% endblock %}
<title>{% block title %}Django REST framework - {{ name }}{% endblock %}</title>
{% block extrahead %}{% endblock %}
{% block blockbots %}<meta name="robots" content="NONE,NOARCHIVE" />{% endblock %}
</head>
<body class="{% block bodyclass %}{% endblock %}">
<div id="container">
<div id="header">
<div id="branding">
<h1 id="site-name">{% block branding %}<a href='http://django-rest-framework.org'>Django REST framework</a> <span class="version"> v {{ version }}</span>{% endblock %}</h1>
</div>
<div id="user-tools">
{% if user.is_active %}Welcome, {{ user }}.{% if logout_url %} <a href='{{ logout_url }}'>Log out</a>{% endif %}{% else %}Anonymous {% if login_url %}<a href='{{ login_url }}'>Log in</a>{% endif %}{% endif %}
{% block userlinks %}{% endblock %}
</div>
{% block nav-global %}{% endblock %}
</div>
<div class="breadcrumbs">
{% block breadcrumbs %}
{% for breadcrumb_name, breadcrumb_url in breadcrumblist %}
<a href="{{ breadcrumb_url }}">{{ breadcrumb_name }}</a> {% if not forloop.last %}&rsaquo;{% endif %}
{% endfor %}
{% endblock %}
</div>
<!-- Content -->
<div id="content" class="{% block coltype %}colM{% endblock %}">
{% if 'OPTIONS' in view.allowed_methods %}
<form action="{{ request.get_full_path }}" method="post">
{% csrf_token %}
<input type="hidden" name="{{ METHOD_PARAM }}" value="OPTIONS" />
<input type="submit" value="OPTIONS" class="default" />
</form>
{% endif %}
<div class='content-main'>
<h1>{{ name }}</h1>
<p>{{ description }}</p>
<div class='module'>
<pre><b>{{ response.status }} {{ response.status_text }}</b>{% autoescape off %}
{% for key, val in response.headers.items %}<b>{{ key }}:</b> {{ val|urlize_quoted_links }}
{% endfor %}
{{ content|urlize_quoted_links }}</pre>{% endautoescape %}</div>
{% if 'GET' in view.allowed_methods %}
<form>
<fieldset class='module aligned'>
<h2>GET {{ name }}</h2>
<div class='submit-row' style='margin: 0; border: 0'>
<a href='{{ request.get_full_path }}' rel="nofollow" style='float: left'>GET</a>
{% for format in available_formats %}
{% with FORMAT_PARAM|add:"="|add:format as param %}
[<a href='{{ request.get_full_path|add_query_param:param }}' rel="nofollow">{{ format }}</a>]
{% endwith %}
{% endfor %}
</div>
</fieldset>
</form>
{% endif %}
{# Only display the POST/PUT/DELETE forms if method tunneling via POST forms is enabled and the user has permissions on this view. #}
{% if METHOD_PARAM and response.status != 403 %}
{% if 'POST' in view.allowed_methods %}
<form action="{{ request.get_full_path }}" method="post" {% if post_form.is_multipart %}enctype="multipart/form-data"{% endif %}>
<fieldset class='module aligned'>
<h2>POST {{ name }}</h2>
{% csrf_token %}
{{ post_form.non_field_errors }}
{% for field in post_form %}
<div class='form-row'>
{{ field.label_tag }}
{{ field }}
<span class='help'>{{ field.help_text }}</span>
{{ field.errors }}
</div>
{% endfor %}
<div class='submit-row' style='margin: 0; border: 0'>
<input type="submit" value="POST" class="default" />
</div>
</fieldset>
</form>
{% endif %}
{% if 'PUT' in view.allowed_methods %}
<form action="{{ request.get_full_path }}" method="post" {% if put_form.is_multipart %}enctype="multipart/form-data"{% endif %}>
<fieldset class='module aligned'>
<h2>PUT {{ name }}</h2>
<input type="hidden" name="{{ METHOD_PARAM }}" value="PUT" />
{% csrf_token %}
{{ put_form.non_field_errors }}
{% for field in put_form %}
<div class='form-row'>
{{ field.label_tag }}
{{ field }}
<span class='help'>{{ field.help_text }}</span>
{{ field.errors }}
</div>
{% endfor %}
<div class='submit-row' style='margin: 0; border: 0'>
<input type="submit" value="PUT" class="default" />
</div>
</fieldset>
</form>
{% endif %}
{% if 'DELETE' in view.allowed_methods %}
<form action="{{ request.get_full_path }}" method="post">
<fieldset class='module aligned'>
<h2>DELETE {{ name }}</h2>
{% csrf_token %}
<input type="hidden" name="{{ METHOD_PARAM }}" value="DELETE" />
<div class='submit-row' style='margin: 0; border: 0'>
<input type="submit" value="DELETE" class="default" />
</div>
</fieldset>
</form>
{% endif %}
{% endif %}
</div>
<!-- END content-main -->
</div>
<!-- END Content -->
{% block footer %}<div id="footer"></div>{% endblock %}
</div>
</body>
</html>

View File

@@ -0,0 +1,44 @@
{% load static %}
<html>
<head>
<link rel="stylesheet" type="text/css" href='{% get_static_prefix %}djangorestframework/css/style.css'/>
</head>
<body class="login">
<div id="container">
<div id="header">
<div id="branding">
<h1 id="site-name">Django REST framework</h1>
</div>
</div>
<div id="content" class="colM">
<div id="content-main">
<form method="post" action="{% url djangorestframework.utils.staticviews.api_login %}" id="login-form">
{% csrf_token %}
<div class="form-row">
<label for="id_username">Username:</label> {{ form.username }}
</div>
<div class="form-row">
<label for="id_password">Password:</label> {{ form.password }}
<input type="hidden" name="next" value="{{ next }}" />
</div>
<div class="form-row">
<label>&nbsp;</label><input type="submit" value="Log in">
</div>
</form>
<script type="text/javascript">
document.getElementById('id_username').focus()
</script>
</div>
<br class="clear">
</div>
<div id="footer"></div>
</div>
</body>
</html>

View File

@@ -0,0 +1,10 @@
from django.template import Library
from urlobject import URLObject
register = Library()
def add_query_param(url, param):
return unicode(URLObject(url).with_query(param))
register.filter('add_query_param', add_query_param)

View File

@@ -0,0 +1,100 @@
"""Adds the custom filter 'urlize_quoted_links'
This is identical to the built-in filter 'urlize' with the exception that
single and double quotes are permitted as leading or trailing punctuation.
"""
# Almost all of this code is copied verbatim from django.utils.html
# LEADING_PUNCTUATION and TRAILING_PUNCTUATION have been modified
import re
import string
from django.utils.safestring import SafeData, mark_safe
from django.utils.encoding import force_unicode
from django.utils.http import urlquote
from django.utils.html import escape
from django import template
# Configuration for urlize() function.
LEADING_PUNCTUATION = ['(', '<', '&lt;', '"', "'"]
TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '&gt;', '"', "'"]
# List of possible strings used for bullets in bulleted lists.
DOTS = ['&middot;', '*', '\xe2\x80\xa2', '&#149;', '&bull;', '&#8226;']
unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)')
word_split_re = re.compile(r'(\s+)')
punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \
('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]),
'|'.join([re.escape(x) for x in TRAILING_PUNCTUATION])))
simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$')
link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+')
html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE)
hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL)
trailing_empty_content_re = re.compile(r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z')
def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True):
"""
Converts any URLs in text into clickable links.
Works on http://, https://, www. links and links ending in .org, .net or
.com. Links can have trailing punctuation (periods, commas, close-parens)
and leading punctuation (opening parens) and it'll still do the right
thing.
If trim_url_limit is not None, the URLs in link text longer than this limit
will truncated to trim_url_limit-3 characters and appended with an elipsis.
If nofollow is True, the URLs in link text will get a rel="nofollow"
attribute.
If autoescape is True, the link text and URLs will get autoescaped.
"""
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
safe_input = isinstance(text, SafeData)
words = word_split_re.split(force_unicode(text))
nofollow_attr = nofollow and ' rel="nofollow"' or ''
for i, word in enumerate(words):
match = None
if '.' in word or '@' in word or ':' in word:
match = punctuation_re.match(word)
if match:
lead, middle, trail = match.groups()
# Make URL we want to point to.
url = None
if middle.startswith('http://') or middle.startswith('https://'):
url = middle
elif middle.startswith('www.') or ('@' not in middle and \
middle and middle[0] in string.ascii_letters + string.digits and \
(middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))):
url = 'http://%s' % middle
elif '@' in middle and not ':' in middle and simple_email_re.match(middle):
url = 'mailto:%s' % middle
nofollow_attr = ''
# Make link.
if url:
trimmed = trim_url(middle)
if autoescape and not safe_input:
lead, trail = escape(lead), escape(trail)
url, trimmed = escape(url), escape(trimmed)
middle = '<a href="%s"%s>%s</a>' % (url, nofollow_attr, trimmed)
words[i] = mark_safe('%s%s%s' % (lead, middle, trail))
else:
if safe_input:
words[i] = mark_safe(word)
elif autoescape:
words[i] = escape(word)
elif safe_input:
words[i] = mark_safe(word)
elif autoescape:
words[i] = escape(word)
return u''.join(words)
#urlize_quoted_links.needs_autoescape = True
urlize_quoted_links.is_safe = True
# Register urlize_quoted_links as a custom filter
# http://docs.djangoproject.com/en/dev/howto/custom-template-tags/
register = template.Library()
register.filter(urlize_quoted_links)

View File

@@ -0,0 +1,13 @@
"""Force import of all modules in this package in order to get the standard test runner to pick up the tests. Yowzers."""
import os
modules = [filename.rsplit('.', 1)[0]
for filename in os.listdir(os.path.dirname(__file__))
if filename.endswith('.py') and not filename.startswith('_')]
__test__ = dict()
for module in modules:
exec("from djangorestframework.tests.%s import __doc__ as module_doc" % module)
exec("from djangorestframework.tests.%s import *" % module)
__test__[module] = module_doc or ""

View File

@@ -0,0 +1,65 @@
from django.test import TestCase
from djangorestframework.compat import RequestFactory
from djangorestframework.views import View
# See: http://www.useragentstring.com/
MSIE_9_USER_AGENT = 'Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))'
MSIE_8_USER_AGENT = 'Mozilla/5.0 (compatible; MSIE 8.0; Windows NT 5.2; Trident/4.0; Media Center PC 4.0; SLCC1; .NET CLR 3.0.04320)'
MSIE_7_USER_AGENT = 'Mozilla/5.0 (Windows; U; MSIE 7.0; Windows NT 6.0; en-US)'
FIREFOX_4_0_USER_AGENT = 'Mozilla/5.0 (Windows; U; Windows NT 6.1; ru; rv:1.9.2.3) Gecko/20100401 Firefox/4.0 (.NET CLR 3.5.30729)'
CHROME_11_0_USER_AGENT = 'Mozilla/5.0 (Windows; U; Windows NT 6.1; en-US) AppleWebKit/534.17 (KHTML, like Gecko) Chrome/11.0.655.0 Safari/534.17'
SAFARI_5_0_USER_AGENT = 'Mozilla/5.0 (X11; U; Linux x86_64; en-ca) AppleWebKit/531.2+ (KHTML, like Gecko) Version/5.0 Safari/531.2+'
OPERA_11_0_MSIE_USER_AGENT = 'Mozilla/4.0 (compatible; MSIE 8.0; X11; Linux x86_64; pl) Opera 11.00'
OPERA_11_0_OPERA_USER_AGENT = 'Opera/9.80 (X11; Linux x86_64; U; pl) Presto/2.7.62 Version/11.00'
class UserAgentMungingTest(TestCase):
"""We need to fake up the accept headers when we deal with MSIE. Blergh.
http://www.gethifi.com/blog/browser-rest-http-accept-headers"""
def setUp(self):
class MockView(View):
permissions = ()
def get(self, request):
return {'a':1, 'b':2, 'c':3}
self.req = RequestFactory()
self.MockView = MockView
self.view = MockView.as_view()
def test_munge_msie_accept_header(self):
"""Send MSIE user agent strings and ensure that we get an HTML response,
even if we set a */* accept header."""
for user_agent in (MSIE_9_USER_AGENT,
MSIE_8_USER_AGENT,
MSIE_7_USER_AGENT):
req = self.req.get('/', HTTP_ACCEPT='*/*', HTTP_USER_AGENT=user_agent)
resp = self.view(req)
self.assertEqual(resp['Content-Type'], 'text/html')
def test_dont_rewrite_msie_accept_header(self):
"""Turn off _IGNORE_IE_ACCEPT_HEADER, send MSIE user agent strings and ensure
that we get a JSON response if we set a */* accept header."""
view = self.MockView.as_view(_IGNORE_IE_ACCEPT_HEADER=False)
for user_agent in (MSIE_9_USER_AGENT,
MSIE_8_USER_AGENT,
MSIE_7_USER_AGENT):
req = self.req.get('/', HTTP_ACCEPT='*/*', HTTP_USER_AGENT=user_agent)
resp = view(req)
self.assertEqual(resp['Content-Type'], 'application/json')
def test_dont_munge_nice_browsers_accept_header(self):
"""Send Non-MSIE user agent strings and ensure that we get a JSON response,
if we set a */* Accept header. (Other browsers will correctly set the Accept header)"""
for user_agent in (FIREFOX_4_0_USER_AGENT,
CHROME_11_0_USER_AGENT,
SAFARI_5_0_USER_AGENT,
OPERA_11_0_MSIE_USER_AGENT,
OPERA_11_0_OPERA_USER_AGENT):
req = self.req.get('/', HTTP_ACCEPT='*/*', HTTP_USER_AGENT=user_agent)
resp = self.view(req)
self.assertEqual(resp['Content-Type'], 'application/json')

View File

@@ -0,0 +1,105 @@
from django.conf.urls.defaults import patterns
from django.contrib.auth.models import User
from django.test import Client, TestCase
from django.utils import simplejson as json
from djangorestframework.views import View
from djangorestframework import permissions
import base64
class MockView(View):
permissions = (permissions.IsAuthenticated,)
def post(self, request):
return {'a': 1, 'b': 2, 'c': 3}
def put(self, request):
return {'a': 1, 'b': 2, 'c': 3}
urlpatterns = patterns('',
(r'^$', MockView.as_view()),
)
class BasicAuthTests(TestCase):
"""Basic authentication"""
urls = 'djangorestframework.tests.authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password)
def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails"""
response = self.csrf_client.post('/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
self.assertEqual(response.status_code, 403)
class SessionAuthTests(TestCase):
"""User session authentication"""
urls = 'djangorestframework.tests.authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
self.non_csrf_client = Client(enforce_csrf_checks=False)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password)
def tearDown(self):
self.csrf_client.logout()
def test_post_form_session_auth_failing_csrf(self):
"""
Ensure POSTing form over session authentication without CSRF token fails.
"""
self.csrf_client.login(username=self.username, password=self.password)
response = self.csrf_client.post('/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
def test_post_form_session_auth_passing(self):
"""
Ensure POSTing form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
response = self.non_csrf_client.post('/', {'example': 'example'})
self.assertEqual(response.status_code, 200)
def test_put_form_session_auth_passing(self):
"""
Ensure PUTting form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
response = self.non_csrf_client.put('/', {'example': 'example'})
self.assertEqual(response.status_code, 200)
def test_post_form_session_auth_failing(self):
"""
Ensure POSTing form over session authentication without logged in user fails.
"""
response = self.csrf_client.post('/', {'example': 'example'})
self.assertEqual(response.status_code, 403)

View File

@@ -0,0 +1,67 @@
from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from djangorestframework.utils.breadcrumbs import get_breadcrumbs
from djangorestframework.views import View
class Root(View):
pass
class ResourceRoot(View):
pass
class ResourceInstance(View):
pass
class NestedResourceRoot(View):
pass
class NestedResourceInstance(View):
pass
urlpatterns = patterns('',
url(r'^$', Root.as_view()),
url(r'^resource/$', ResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()),
url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()),
)
class BreadcrumbTests(TestCase):
"""Tests the breadcrumb functionality used by the HTML renderer."""
urls = 'djangorestframework.tests.breadcrumbs'
def test_root_breadcrumbs(self):
url = '/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
def test_resource_root_breadcrumbs(self):
url = '/resource/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/')])
def test_resource_instance_breadcrumbs(self):
url = '/resource/123'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123')])
def test_nested_resource_breadcrumbs(self):
url = '/resource/123/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123'),
('Nested Resource Root', '/resource/123/')])
def test_nested_resource_instance_breadcrumbs(self):
url = '/resource/123/abc'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123'),
('Nested Resource Root', '/resource/123/'),
('Nested Resource Instance', '/resource/123/abc')])
def test_broken_url_breadcrumbs_handled_gracefully(self):
url = '/foobar'
self.assertEqual(get_breadcrumbs(url), [('Root', '/')])

View File

@@ -0,0 +1,233 @@
"""
Tests for content parsing, and form-overloaded content parsing.
"""
from django.conf.urls.defaults import patterns
from django.contrib.auth.models import User
from django.test import TestCase, Client
from djangorestframework import status
from djangorestframework.authentication import UserLoggedInAuthentication
from djangorestframework.compat import RequestFactory, unittest
from djangorestframework.mixins import RequestMixin
from djangorestframework.parsers import FormParser, MultiPartParser, \
PlainTextParser, JSONParser
from djangorestframework.response import Response
from djangorestframework.views import View
class MockView(View):
authentication = (UserLoggedInAuthentication,)
def post(self, request):
if request.POST.get('example') is not None:
return Response(status.HTTP_200_OK)
return Response(status.INTERNAL_SERVER_ERROR)
urlpatterns = patterns('',
(r'^$', MockView.as_view()),
)
class TestContentParsing(TestCase):
def setUp(self):
self.req = RequestFactory()
def ensure_determines_no_content_GET(self, view):
"""Ensure view.DATA returns None for GET request with no content."""
view.request = self.req.get('/')
self.assertEqual(view.DATA, None)
def ensure_determines_no_content_HEAD(self, view):
"""Ensure view.DATA returns None for HEAD request."""
view.request = self.req.head('/')
self.assertEqual(view.DATA, None)
def ensure_determines_form_content_POST(self, view):
"""Ensure view.DATA returns content for POST request with form content."""
form_data = {'qwerty': 'uiop'}
view.parsers = (FormParser, MultiPartParser)
view.request = self.req.post('/', data=form_data)
self.assertEqual(view.DATA.items(), form_data.items())
def ensure_determines_non_form_content_POST(self, view):
"""Ensure view.RAW_CONTENT returns content for POST request with non-form content."""
content = 'qwerty'
content_type = 'text/plain'
view.parsers = (PlainTextParser,)
view.request = self.req.post('/', content, content_type=content_type)
self.assertEqual(view.DATA, content)
def ensure_determines_form_content_PUT(self, view):
"""Ensure view.RAW_CONTENT returns content for PUT request with form content."""
form_data = {'qwerty': 'uiop'}
view.parsers = (FormParser, MultiPartParser)
view.request = self.req.put('/', data=form_data)
self.assertEqual(view.DATA.items(), form_data.items())
def ensure_determines_non_form_content_PUT(self, view):
"""Ensure view.RAW_CONTENT returns content for PUT request with non-form content."""
content = 'qwerty'
content_type = 'text/plain'
view.parsers = (PlainTextParser,)
view.request = self.req.post('/', content, content_type=content_type)
self.assertEqual(view.DATA, content)
def test_standard_behaviour_determines_no_content_GET(self):
"""Ensure view.DATA returns None for GET request with no content."""
self.ensure_determines_no_content_GET(RequestMixin())
def test_standard_behaviour_determines_no_content_HEAD(self):
"""Ensure view.DATA returns None for HEAD request."""
self.ensure_determines_no_content_HEAD(RequestMixin())
def test_standard_behaviour_determines_form_content_POST(self):
"""Ensure view.DATA returns content for POST request with form content."""
self.ensure_determines_form_content_POST(RequestMixin())
def test_standard_behaviour_determines_non_form_content_POST(self):
"""Ensure view.DATA returns content for POST request with non-form content."""
self.ensure_determines_non_form_content_POST(RequestMixin())
def test_standard_behaviour_determines_form_content_PUT(self):
"""Ensure view.DATA returns content for PUT request with form content."""
self.ensure_determines_form_content_PUT(RequestMixin())
def test_standard_behaviour_determines_non_form_content_PUT(self):
"""Ensure view.DATA returns content for PUT request with non-form content."""
self.ensure_determines_non_form_content_PUT(RequestMixin())
def test_overloaded_behaviour_allows_content_tunnelling(self):
"""Ensure request.DATA returns content for overloaded POST request"""
content = 'qwerty'
content_type = 'text/plain'
view = RequestMixin()
form_data = {view._CONTENT_PARAM: content,
view._CONTENTTYPE_PARAM: content_type}
view.request = self.req.post('/', form_data)
view.parsers = (PlainTextParser,)
self.assertEqual(view.DATA, content)
def test_accessing_post_after_data_form(self):
"""Ensures request.POST can be accessed after request.DATA in form request"""
form_data = {'qwerty': 'uiop'}
view = RequestMixin()
view.parsers = (FormParser, MultiPartParser)
view.request = self.req.post('/', data=form_data)
self.assertEqual(view.DATA.items(), form_data.items())
self.assertEqual(view.request.POST.items(), form_data.items())
@unittest.skip('This test was disabled some time ago for some reason')
def test_accessing_post_after_data_for_json(self):
"""Ensures request.POST can be accessed after request.DATA in json request"""
from django.utils import simplejson as json
data = {'qwerty': 'uiop'}
content = json.dumps(data)
content_type = 'application/json'
view = RequestMixin()
view.parsers = (JSONParser,)
view.request = self.req.post('/', content, content_type=content_type)
self.assertEqual(view.DATA.items(), data.items())
self.assertEqual(view.request.POST.items(), [])
def test_accessing_post_after_data_for_overloaded_json(self):
"""Ensures request.POST can be accessed after request.DATA in overloaded json request"""
from django.utils import simplejson as json
data = {'qwerty': 'uiop'}
content = json.dumps(data)
content_type = 'application/json'
view = RequestMixin()
view.parsers = (JSONParser,)
form_data = {view._CONTENT_PARAM: content,
view._CONTENTTYPE_PARAM: content_type}
view.request = self.req.post('/', data=form_data)
self.assertEqual(view.DATA.items(), data.items())
self.assertEqual(view.request.POST.items(), form_data.items())
def test_accessing_data_after_post_form(self):
"""Ensures request.DATA can be accessed after request.POST in form request"""
form_data = {'qwerty': 'uiop'}
view = RequestMixin()
view.parsers = (FormParser, MultiPartParser)
view.request = self.req.post('/', data=form_data)
self.assertEqual(view.request.POST.items(), form_data.items())
self.assertEqual(view.DATA.items(), form_data.items())
def test_accessing_data_after_post_for_json(self):
"""Ensures request.DATA can be accessed after request.POST in json request"""
from django.utils import simplejson as json
data = {'qwerty': 'uiop'}
content = json.dumps(data)
content_type = 'application/json'
view = RequestMixin()
view.parsers = (JSONParser,)
view.request = self.req.post('/', content, content_type=content_type)
post_items = view.request.POST.items()
self.assertEqual(len(post_items), 1)
self.assertEqual(len(post_items[0]), 2)
self.assertEqual(post_items[0][0], content)
self.assertEqual(view.DATA.items(), data.items())
def test_accessing_data_after_post_for_overloaded_json(self):
"""Ensures request.DATA can be accessed after request.POST in overloaded json request"""
from django.utils import simplejson as json
data = {'qwerty': 'uiop'}
content = json.dumps(data)
content_type = 'application/json'
view = RequestMixin()
view.parsers = (JSONParser,)
form_data = {view._CONTENT_PARAM: content,
view._CONTENTTYPE_PARAM: content_type}
view.request = self.req.post('/', data=form_data)
self.assertEqual(view.request.POST.items(), form_data.items())
self.assertEqual(view.DATA.items(), data.items())
class TestContentParsingWithAuthentication(TestCase):
urls = 'djangorestframework.tests.content'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password)
self.req = RequestFactory()
def test_user_logged_in_authentication_has_post_when_not_logged_in(self):
"""Ensures request.POST exists after UserLoggedInAuthentication when user doesn't log in"""
content = {'example': 'example'}
response = self.client.post('/', content)
self.assertEqual(status.HTTP_200_OK, response.status_code, "POST data is malformed")
response = self.csrf_client.post('/', content)
self.assertEqual(status.HTTP_200_OK, response.status_code, "POST data is malformed")
# def test_user_logged_in_authentication_has_post_when_logged_in(self):
# """Ensures request.POST exists after UserLoggedInAuthentication when user does log in"""
# self.client.login(username='john', password='password')
# self.csrf_client.login(username='john', password='password')
# content = {'example': 'example'}
# response = self.client.post('/', content)
# self.assertEqual(status.OK, response.status_code, "POST data is malformed")
# response = self.csrf_client.post('/', content)
# self.assertEqual(status.OK, response.status_code, "POST data is malformed")

View File

@@ -0,0 +1,111 @@
from django.test import TestCase
from djangorestframework.views import View
from djangorestframework.compat import apply_markdown
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
====================
* list
* list
another header
--------------
code block
indented
# hash style header #"""
# If markdown is installed we also test it's working
# (and that our wrapped forces '=' to h2 and '-' to h3)
# We support markdown < 2.1 and markdown >= 2.1
MARKED_DOWN_lt_21 = """<h2>an example docstring</h2>
<ul>
<li>list</li>
<li>list</li>
</ul>
<h3>another header</h3>
<pre><code>code block
</code></pre>
<p>indented</p>
<h2 id="hash_style_header">hash style header</h2>"""
MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
<ul>
<li>list</li>
<li>list</li>
</ul>
<h3 id="another-header">another header</h3>
<pre><code>code block
</code></pre>
<p>indented</p>
<h2 id="hash-style-header">hash style header</h2>"""
class TestViewNamesAndDescriptions(TestCase):
def test_resource_name_uses_classname_by_default(self):
"""Ensure Resource names are based on the classname by default."""
class MockView(View):
pass
self.assertEquals(MockView().get_name(), 'Mock')
def test_resource_name_can_be_set_explicitly(self):
"""Ensure Resource names can be set using the 'get_name' method."""
example = 'Some Other Name'
class MockView(View):
def get_name(self):
return example
self.assertEquals(MockView().get_name(), example)
def test_resource_description_uses_docstring_by_default(self):
"""Ensure Resource names are based on the docstring by default."""
class MockView(View):
"""an example docstring
====================
* list
* list
another header
--------------
code block
indented
# hash style header #"""
self.assertEquals(MockView().get_description(), DESCRIPTION)
def test_resource_description_can_be_set_explicitly(self):
"""Ensure Resource descriptions can be set using the 'get_description' method."""
example = 'Some other description'
class MockView(View):
"""docstring"""
def get_description(self):
return example
self.assertEquals(MockView().get_description(), example)
def test_resource_description_does_not_require_docstring(self):
"""Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method."""
example = 'Some other description'
class MockView(View):
def get_description(self):
return example
self.assertEquals(MockView().get_description(), example)
def test_resource_description_can_be_empty(self):
"""Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string."""
class MockView(View):
pass
self.assertEquals(MockView().get_description(), '')
def test_markdown(self):
"""Ensure markdown to HTML works as expected"""
if apply_markdown:
gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
self.assertTrue(gte_21_match or lt_21_match)

View File

@@ -0,0 +1,32 @@
from django.test import TestCase
from django import forms
from djangorestframework.compat import RequestFactory
from djangorestframework.views import View
from djangorestframework.resources import FormResource
import StringIO
class UploadFilesTests(TestCase):
"""Check uploading of files"""
def setUp(self):
self.factory = RequestFactory()
def test_upload_file(self):
class FileForm(forms.Form):
file = forms.FileField()
class MockView(View):
permissions = ()
form = FileForm
def post(self, request, *args, **kwargs):
return {'FILE_NAME': self.CONTENT['file'].name,
'FILE_CONTENT': self.CONTENT['file'].read()}
file = StringIO.StringIO('stuff')
file.name = 'stuff.txt'
request = self.factory.post('/', {'file': file})
view = MockView.as_view()
response = view(request)
self.assertEquals(response.content, '{"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"}')

View File

@@ -0,0 +1,32 @@
from django.test import TestCase
from djangorestframework.compat import RequestFactory
from djangorestframework.mixins import RequestMixin
class TestMethodOverloading(TestCase):
def setUp(self):
self.req = RequestFactory()
def test_standard_behaviour_determines_GET(self):
"""GET requests identified"""
view = RequestMixin()
view.request = self.req.get('/')
self.assertEqual(view.method, 'GET')
def test_standard_behaviour_determines_POST(self):
"""POST requests identified"""
view = RequestMixin()
view.request = self.req.post('/')
self.assertEqual(view.method, 'POST')
def test_overloaded_POST_behaviour_determines_overloaded_method(self):
"""POST requests can be overloaded to another method by setting a reserved form field"""
view = RequestMixin()
view.request = self.req.post('/', {view._METHOD_PARAM: 'DELETE'})
self.assertEqual(view.method, 'DELETE')
def test_HEAD_is_a_valid_method(self):
"""HEAD requests identified"""
view = RequestMixin()
view.request = self.req.head('/')
self.assertEqual(view.method, 'HEAD')

View File

@@ -0,0 +1,291 @@
"""Tests for the mixin module"""
from django.test import TestCase
from django.utils import simplejson as json
from djangorestframework import status
from djangorestframework.compat import RequestFactory
from django.contrib.auth.models import Group, User
from djangorestframework.mixins import CreateModelMixin, PaginatorMixin, ReadModelMixin
from djangorestframework.resources import ModelResource
from djangorestframework.response import Response, ErrorResponse
from djangorestframework.tests.models import CustomUser
from djangorestframework.tests.testcases import TestModelsTestCase
from djangorestframework.views import View
class TestModelRead(TestModelsTestCase):
"""Tests on ReadModelMixin"""
def setUp(self):
super(TestModelRead, self).setUp()
self.req = RequestFactory()
def test_read(self):
Group.objects.create(name='other group')
group = Group.objects.create(name='my group')
class GroupResource(ModelResource):
model = Group
request = self.req.get('/groups')
mixin = ReadModelMixin()
mixin.resource = GroupResource
response = mixin.get(request, id=group.id)
self.assertEquals(group.name, response.name)
def test_read_404(self):
class GroupResource(ModelResource):
model = Group
request = self.req.get('/groups')
mixin = ReadModelMixin()
mixin.resource = GroupResource
self.assertRaises(ErrorResponse, mixin.get, request, id=12345)
class TestModelCreation(TestModelsTestCase):
"""Tests on CreateModelMixin"""
def setUp(self):
super(TestModelsTestCase, self).setUp()
self.req = RequestFactory()
def test_creation(self):
self.assertEquals(0, Group.objects.count())
class GroupResource(ModelResource):
model = Group
form_data = {'name': 'foo'}
request = self.req.post('/groups', data=form_data)
mixin = CreateModelMixin()
mixin.resource = GroupResource
mixin.CONTENT = form_data
response = mixin.post(request)
self.assertEquals(1, Group.objects.count())
self.assertEquals('foo', response.cleaned_content.name)
def test_creation_with_m2m_relation(self):
class UserResource(ModelResource):
model = User
def url(self, instance):
return "/users/%i" % instance.id
group = Group(name='foo')
group.save()
form_data = {
'username': 'bar',
'password': 'baz',
'groups': [group.id]
}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
mixin = CreateModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.post(request)
self.assertEquals(1, User.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo', response.cleaned_content.groups.all()[0].name)
def test_creation_with_m2m_relation_through(self):
"""
Tests creation where the m2m relation uses a through table
"""
class UserResource(ModelResource):
model = CustomUser
def url(self, instance):
return "/customusers/%i" % instance.id
form_data = {'username': 'bar0', 'groups': []}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = []
mixin = CreateModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.post(request)
self.assertEquals(1, CustomUser.objects.count())
self.assertEquals(0, response.cleaned_content.groups.count())
group = Group(name='foo1')
group.save()
form_data = {'username': 'bar1', 'groups': [group.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
mixin = CreateModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.post(request)
self.assertEquals(2, CustomUser.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
group2 = Group(name='foo2')
group2.save()
form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group, group2]
mixin = CreateModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.post(request)
self.assertEquals(3, CustomUser.objects.count())
self.assertEquals(2, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name)
class MockPaginatorView(PaginatorMixin, View):
total = 60
def get(self, request):
return range(0, self.total)
def post(self, request):
return Response(status.HTTP_201_CREATED, {'status': 'OK'})
class TestPagination(TestCase):
def setUp(self):
self.req = RequestFactory()
def test_default_limit(self):
""" Tests if pagination works without overwriting the limit """
request = self.req.get('/paginator')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(MockPaginatorView.total, content['total'])
self.assertEqual(MockPaginatorView.limit, content['per_page'])
self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
def test_overwriting_limit(self):
""" Tests if the limit can be overwritten """
limit = 10
request = self.req.get('/paginator')
response = MockPaginatorView.as_view(limit=limit)(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(content['per_page'], limit)
self.assertEqual(range(0, limit), content['results'])
def test_limit_param(self):
""" Tests if the client can set the limit """
from math import ceil
limit = 5
num_pages = int(ceil(MockPaginatorView.total / float(limit)))
request = self.req.get('/paginator/?limit=%d' % limit)
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(MockPaginatorView.total, content['total'])
self.assertEqual(limit, content['per_page'])
self.assertEqual(num_pages, content['pages'])
def test_exceeding_limit(self):
""" Makes sure the client cannot exceed the default limit """
from math import ceil
limit = MockPaginatorView.limit + 10
num_pages = int(ceil(MockPaginatorView.total / float(limit)))
request = self.req.get('/paginator/?limit=%d' % limit)
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(MockPaginatorView.total, content['total'])
self.assertNotEqual(limit, content['per_page'])
self.assertNotEqual(num_pages, content['pages'])
self.assertEqual(MockPaginatorView.limit, content['per_page'])
def test_only_works_for_get(self):
""" Pagination should only work for GET requests """
request = self.req.post('/paginator', data={'content': 'spam'})
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(None, content.get('per_page'))
self.assertEqual('OK', content['status'])
def test_non_int_page(self):
""" Tests that it can handle invalid values """
request = self.req.get('/paginator/?page=spam')
response = MockPaginatorView.as_view()(request)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_page_range(self):
""" Tests that the page range is handle correctly """
request = self.req.get('/paginator/?page=0')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
request = self.req.get('/paginator/')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
num_pages = content['pages']
request = self.req.get('/paginator/?page=%d' % num_pages)
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(range(MockPaginatorView.limit*(num_pages-1), MockPaginatorView.total), content['results'])
request = self.req.get('/paginator/?page=%d' % (num_pages + 1,))
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_existing_query_parameters_are_preserved(self):
""" Tests that existing query parameters are preserved when
generating next/previous page links """
request = self.req.get('/paginator/?foo=bar&another=something')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue('foo=bar' in content['next'])
self.assertTrue('another=something' in content['next'])
self.assertTrue('page=2' in content['next'])
def test_duplicate_parameters_are_not_created(self):
""" Regression: ensure duplicate "page" parameters are not added to
paginated URLs. So page 1 should contain ?page=2, not ?page=1&page=2 """
request = self.req.get('/paginator/?page=1')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertTrue('page=2' in content['next'])
self.assertFalse('page=1' in content['next'])

View File

@@ -0,0 +1,28 @@
from django.db import models
from django.contrib.auth.models import Group
class CustomUser(models.Model):
"""
A custom user model, which uses a 'through' table for the foreign key
"""
username = models.CharField(max_length=255, unique=True)
groups = models.ManyToManyField(
to=Group, blank=True, null=True, through='UserGroupMap'
)
@models.permalink
def get_absolute_url(self):
return ('custom_user', (), {
'pk': self.id
})
class UserGroupMap(models.Model):
user = models.ForeignKey(to=CustomUser)
group = models.ForeignKey(to=Group)
@models.permalink
def get_absolute_url(self):
return ('user_group_map', (), {
'pk': self.id
})

View File

@@ -0,0 +1,87 @@
from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.forms import ModelForm
from django.contrib.auth.models import Group, User
from djangorestframework.resources import ModelResource
from djangorestframework.views import ListOrCreateModelView, InstanceModelView
from djangorestframework.tests.models import CustomUser
from djangorestframework.tests.testcases import TestModelsTestCase
class GroupResource(ModelResource):
model = Group
class UserForm(ModelForm):
class Meta:
model = User
exclude = ('last_login', 'date_joined')
class UserResource(ModelResource):
model = User
form = UserForm
class CustomUserResource(ModelResource):
model = CustomUser
urlpatterns = patterns('',
url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'),
url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)),
url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'),
url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)),
url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'),
url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)),
)
class ModelViewTests(TestModelsTestCase):
"""Test the model views djangorestframework provides"""
urls = 'djangorestframework.tests.modelviews'
def test_creation(self):
"""Ensure that a model object can be created"""
self.assertEqual(0, Group.objects.count())
response = self.client.post('/groups/', {'name': 'foo'})
self.assertEqual(response.status_code, 201)
self.assertEqual(1, Group.objects.count())
self.assertEqual('foo', Group.objects.all()[0].name)
def test_creation_with_m2m_relation(self):
"""Ensure that a model object with a m2m relation can be created"""
group = Group(name='foo')
group.save()
self.assertEqual(0, User.objects.count())
response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]})
self.assertEqual(response.status_code, 201)
self.assertEqual(1, User.objects.count())
user = User.objects.all()[0]
self.assertEqual('bar', user.username)
self.assertEqual('baz', user.password)
self.assertEqual(1, user.groups.count())
group = user.groups.all()[0]
self.assertEqual('foo', group.name)
def test_creation_with_m2m_relation_through(self):
"""
Ensure that a model object with a m2m relation can be created where that
relation uses a through table
"""
group = Group(name='foo')
group.save()
self.assertEqual(0, User.objects.count())
response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]})
self.assertEqual(response.status_code, 201)
self.assertEqual(1, CustomUser.objects.count())
user = CustomUser.objects.all()[0]
self.assertEqual('bar', user.username)
self.assertEqual(1, user.groups.count())
group = user.groups.all()[0]
self.assertEqual('foo', group.name)

View File

@@ -0,0 +1,212 @@
import time
from django.conf.urls.defaults import patterns, url, include
from django.contrib.auth.models import User
from django.test import Client, TestCase
from djangorestframework.views import View
# Since oauth2 / django-oauth-plus are optional dependancies, we don't want to
# always run these tests.
# Unfortunatly we can't skip tests easily until 2.7, se we'll just do this for now.
try:
import oauth2 as oauth
from oauth_provider.decorators import oauth_required
from oauth_provider.models import Resource, Consumer, Token
except ImportError:
pass
else:
# Alrighty, we're good to go here.
class ClientView(View):
def get(self, request):
return {'resource': 'Protected!'}
urlpatterns = patterns('',
url(r'^$', oauth_required(ClientView.as_view())),
url(r'^oauth/', include('oauth_provider.urls')),
url(r'^accounts/login/$', 'djangorestframework.utils.staticviews.api_login'),
)
class OAuthTests(TestCase):
"""
OAuth authentication:
* the user would like to access his API data from a third-party website
* the third-party website proposes a link to get that API data
* the user is redirected to the API and must log in if not authenticated
* the API displays a webpage to confirm that the user trusts the third-party website
* if confirmed, the user is redirected to the third-party website through the callback view
* the third-party website is able to retrieve data from the API
"""
urls = 'djangorestframework.tests.oauthentication'
def setUp(self):
self.client = Client()
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password)
# OAuth requirements
self.resource = Resource(name='data', url='/')
self.resource.save()
self.CONSUMER_KEY = 'dpf43f3p2l4k3l03'
self.CONSUMER_SECRET = 'kd94hf93k423kf44'
self.consumer = Consumer(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
name='api.example.com', user=self.user)
self.consumer.save()
def test_oauth_invalid_and_anonymous_access(self):
"""
Verify that the resource is protected and the OAuth authorization view
require the user to be logged in.
"""
response = self.client.get('/')
self.assertEqual(response.content, 'Invalid request parameters.')
self.assertEqual(response.status_code, 401)
response = self.client.get('/oauth/authorize/', follow=True)
self.assertRedirects(response, '/accounts/login/?next=/oauth/authorize/')
def test_oauth_authorize_access(self):
"""
Verify that once logged in, the user can access the authorization page
but can't display the page because the request token is not specified.
"""
self.client.login(username=self.username, password=self.password)
response = self.client.get('/oauth/authorize/', follow=True)
self.assertEqual(response.content, 'No request token specified.')
def _create_request_token_parameters(self):
"""
A shortcut to create request's token parameters.
"""
return {
'oauth_consumer_key': self.CONSUMER_KEY,
'oauth_signature_method': 'PLAINTEXT',
'oauth_signature': '%s&' % self.CONSUMER_SECRET,
'oauth_timestamp': str(int(time.time())),
'oauth_nonce': 'requestnonce',
'oauth_version': '1.0',
'oauth_callback': 'http://api.example.com/request_token_ready',
'scope': 'data',
}
def test_oauth_request_token_retrieval(self):
"""
Verify that the request token can be retrieved by the server.
"""
response = self.client.get("/oauth/request_token/",
self._create_request_token_parameters())
self.assertEqual(response.status_code, 200)
token = list(Token.objects.all())[-1]
self.failIf(token.key not in response.content)
self.failIf(token.secret not in response.content)
def test_oauth_user_request_authorization(self):
"""
Verify that the user can access the authorization page once logged in
and the request token has been retrieved.
"""
# Setup
response = self.client.get("/oauth/request_token/",
self._create_request_token_parameters())
token = list(Token.objects.all())[-1]
# Starting the test here
self.client.login(username=self.username, password=self.password)
parameters = {'oauth_token': token.key,}
response = self.client.get("/oauth/authorize/", parameters)
self.assertEqual(response.status_code, 200)
self.failIf(not response.content.startswith('Fake authorize view for api.example.com with params: oauth_token='))
self.assertEqual(token.is_approved, 0)
parameters['authorize_access'] = 1 # fake authorization by the user
response = self.client.post("/oauth/authorize/", parameters)
self.assertEqual(response.status_code, 302)
self.failIf(not response['Location'].startswith('http://api.example.com/request_token_ready?oauth_verifier='))
token = Token.objects.get(key=token.key)
self.failIf(token.key not in response['Location'])
self.assertEqual(token.is_approved, 1)
def _create_access_token_parameters(self, token):
"""
A shortcut to create access' token parameters.
"""
return {
'oauth_consumer_key': self.CONSUMER_KEY,
'oauth_token': token.key,
'oauth_signature_method': 'PLAINTEXT',
'oauth_signature': '%s&%s' % (self.CONSUMER_SECRET, token.secret),
'oauth_timestamp': str(int(time.time())),
'oauth_nonce': 'accessnonce',
'oauth_version': '1.0',
'oauth_verifier': token.verifier,
'scope': 'data',
}
def test_oauth_access_token_retrieval(self):
"""
Verify that the request token can be retrieved by the server.
"""
# Setup
response = self.client.get("/oauth/request_token/",
self._create_request_token_parameters())
token = list(Token.objects.all())[-1]
self.client.login(username=self.username, password=self.password)
parameters = {'oauth_token': token.key,}
response = self.client.get("/oauth/authorize/", parameters)
parameters['authorize_access'] = 1 # fake authorization by the user
response = self.client.post("/oauth/authorize/", parameters)
token = Token.objects.get(key=token.key)
# Starting the test here
response = self.client.get("/oauth/access_token/", self._create_access_token_parameters(token))
self.assertEqual(response.status_code, 200)
self.failIf(not response.content.startswith('oauth_token_secret='))
access_token = list(Token.objects.filter(token_type=Token.ACCESS))[-1]
self.failIf(access_token.key not in response.content)
self.failIf(access_token.secret not in response.content)
self.assertEqual(access_token.user.username, 'john')
def _create_access_parameters(self, access_token):
"""
A shortcut to create access' parameters.
"""
parameters = {
'oauth_consumer_key': self.CONSUMER_KEY,
'oauth_token': access_token.key,
'oauth_signature_method': 'HMAC-SHA1',
'oauth_timestamp': str(int(time.time())),
'oauth_nonce': 'accessresourcenonce',
'oauth_version': '1.0',
}
oauth_request = oauth.Request.from_token_and_callback(access_token,
http_url='http://testserver/', parameters=parameters)
signature_method = oauth.SignatureMethod_HMAC_SHA1()
signature = signature_method.sign(oauth_request, self.consumer, access_token)
parameters['oauth_signature'] = signature
return parameters
def test_oauth_protected_resource_access(self):
"""
Verify that the request token can be retrieved by the server.
"""
# Setup
response = self.client.get("/oauth/request_token/",
self._create_request_token_parameters())
token = list(Token.objects.all())[-1]
self.client.login(username=self.username, password=self.password)
parameters = {'oauth_token': token.key,}
response = self.client.get("/oauth/authorize/", parameters)
parameters['authorize_access'] = 1 # fake authorization by the user
response = self.client.post("/oauth/authorize/", parameters)
token = Token.objects.get(key=token.key)
response = self.client.get("/oauth/access_token/", self._create_access_token_parameters(token))
access_token = list(Token.objects.filter(token_type=Token.ACCESS))[-1]
# Starting the test here
response = self.client.get("/", self._create_access_token_parameters(access_token))
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, '{"resource": "Protected!"}')

View File

@@ -0,0 +1,11 @@
"""Tests for the djangorestframework package setup."""
from django.test import TestCase
import djangorestframework
class TestVersion(TestCase):
"""Simple sanity test to check the VERSION exists"""
def test_version(self):
"""Ensure the VERSION exists."""
djangorestframework.VERSION

View File

@@ -0,0 +1,210 @@
# """
# ..
# >>> from djangorestframework.parsers import FormParser
# >>> from djangorestframework.compat import RequestFactory
# >>> from djangorestframework.views import View
# >>> from StringIO import StringIO
# >>> from urllib import urlencode
# >>> req = RequestFactory().get('/')
# >>> some_view = View()
# >>> some_view.request = req # Make as if this request had been dispatched
#
# FormParser
# ============
#
# Data flatening
# ----------------
#
# Here is some example data, which would eventually be sent along with a post request :
#
# >>> inpt = urlencode([
# ... ('key1', 'bla1'),
# ... ('key2', 'blo1'), ('key2', 'blo2'),
# ... ])
#
# Default behaviour for :class:`parsers.FormParser`, is to return a single value for each parameter :
#
# >>> (data, files) = FormParser(some_view).parse(StringIO(inpt))
# >>> data == {'key1': 'bla1', 'key2': 'blo1'}
# True
#
# However, you can customize this behaviour by subclassing :class:`parsers.FormParser`, and overriding :meth:`parsers.FormParser.is_a_list` :
#
# >>> class MyFormParser(FormParser):
# ...
# ... def is_a_list(self, key, val_list):
# ... return len(val_list) > 1
#
# This new parser only flattens the lists of parameters that contain a single value.
#
# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
# >>> data == {'key1': 'bla1', 'key2': ['blo1', 'blo2']}
# True
#
# .. note:: The same functionality is available for :class:`parsers.MultiPartParser`.
#
# Submitting an empty list
# --------------------------
#
# When submitting an empty select multiple, like this one ::
#
# <select multiple="multiple" name="key2"></select>
#
# The browsers usually strip the parameter completely. A hack to avoid this, and therefore being able to submit an empty select multiple, is to submit a value that tells the server that the list is empty ::
#
# <select multiple="multiple" name="key2"><option value="_empty"></select>
#
# :class:`parsers.FormParser` provides the server-side implementation for this hack. Considering the following posted data :
#
# >>> inpt = urlencode([
# ... ('key1', 'blo1'), ('key1', '_empty'),
# ... ('key2', '_empty'),
# ... ])
#
# :class:`parsers.FormParser` strips the values ``_empty`` from all the lists.
#
# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
# >>> data == {'key1': 'blo1'}
# True
#
# Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a list, so the parser just stripped it.
#
# >>> class MyFormParser(FormParser):
# ...
# ... def is_a_list(self, key, val_list):
# ... return key == 'key2'
# ...
# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
# >>> data == {'key1': 'blo1', 'key2': []}
# True
#
# Better like that. Note that you can configure something else than ``_empty`` for the empty value by setting :attr:`parsers.FormParser.EMPTY_VALUE`.
# """
# import httplib, mimetypes
# from tempfile import TemporaryFile
# from django.test import TestCase
# from djangorestframework.compat import RequestFactory
# from djangorestframework.parsers import MultiPartParser
# from djangorestframework.views import View
# from StringIO import StringIO
#
# def encode_multipart_formdata(fields, files):
# """For testing multipart parser.
# fields is a sequence of (name, value) elements for regular form fields.
# files is a sequence of (name, filename, value) elements for data to be uploaded as files
# Return (content_type, body)."""
# BOUNDARY = '----------ThIs_Is_tHe_bouNdaRY_$'
# CRLF = '\r\n'
# L = []
# for (key, value) in fields:
# L.append('--' + BOUNDARY)
# L.append('Content-Disposition: form-data; name="%s"' % key)
# L.append('')
# L.append(value)
# for (key, filename, value) in files:
# L.append('--' + BOUNDARY)
# L.append('Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename))
# L.append('Content-Type: %s' % get_content_type(filename))
# L.append('')
# L.append(value)
# L.append('--' + BOUNDARY + '--')
# L.append('')
# body = CRLF.join(L)
# content_type = 'multipart/form-data; boundary=%s' % BOUNDARY
# return content_type, body
#
# def get_content_type(filename):
# return mimetypes.guess_type(filename)[0] or 'application/octet-stream'
#
#class TestMultiPartParser(TestCase):
# def setUp(self):
# self.req = RequestFactory()
# self.content_type, self.body = encode_multipart_formdata([('key1', 'val1'), ('key1', 'val2')],
# [('file1', 'pic.jpg', 'blablabla'), ('file1', 't.txt', 'blobloblo')])
#
# def test_multipartparser(self):
# """Ensure that MultiPartParser can parse multipart/form-data that contains a mix of several files and parameters."""
# post_req = RequestFactory().post('/', self.body, content_type=self.content_type)
# view = View()
# view.request = post_req
# (data, files) = MultiPartParser(view).parse(StringIO(self.body))
# self.assertEqual(data['key1'], 'val1')
# self.assertEqual(files['file1'].read(), 'blablabla')
from StringIO import StringIO
from cgi import parse_qs
from django import forms
from django.test import TestCase
from djangorestframework.parsers import FormParser
from djangorestframework.parsers import XMLParser
import datetime
class Form(forms.Form):
field1 = forms.CharField(max_length=3)
field2 = forms.CharField()
class TestFormParser(TestCase):
def setUp(self):
self.string = "field1=abc&field2=defghijk"
def test_parse(self):
""" Make sure the `QueryDict` works OK """
parser = FormParser(None)
stream = StringIO(self.string)
(data, files) = parser.parse(stream)
self.assertEqual(Form(data).is_valid(), True)
class TestXMLParser(TestCase):
def setUp(self):
self._input = StringIO(
'<?xml version="1.0" encoding="utf-8"?>'
'<root>'
'<field_a>121.0</field_a>'
'<field_b>dasd</field_b>'
'<field_c></field_c>'
'<field_d>2011-12-25 12:45:00</field_d>'
'</root>'
)
self._data = {
'field_a': 121,
'field_b': 'dasd',
'field_c': None,
'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
}
self._complex_data_input = StringIO(
'<?xml version="1.0" encoding="utf-8"?>'
'<root>'
'<creation_date>2011-12-25 12:45:00</creation_date>'
'<sub_data_list>'
'<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>'
'<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>'
'</sub_data_list>'
'<name>name</name>'
'</root>'
)
self._complex_data = {
"creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
"name": "name",
"sub_data_list": [
{
"sub_id": 1,
"sub_name": "first"
},
{
"sub_id": 2,
"sub_name": "second"
}
]
}
def test_parse(self):
parser = XMLParser(None)
(data, files) = parser.parse(self._input)
self.assertEqual(data, self._data)
def test_complex_data_parse(self):
parser = XMLParser(None)
(data, files) = parser.parse(self._complex_data_input)
self.assertEqual(data, self._complex_data)

View File

@@ -0,0 +1,411 @@
import re
from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from djangorestframework import status
from djangorestframework.views import View
from djangorestframework.compat import View as DjangoView
from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
XMLRenderer, JSONPRenderer, DocumentingHTMLRenderer
from djangorestframework.parsers import JSONParser, YAMLParser, XMLParser
from djangorestframework.mixins import ResponseMixin
from djangorestframework.response import Response
from StringIO import StringIO
import datetime
from decimal import Decimal
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x
RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x
class RendererA(BaseRenderer):
media_type = 'mock/renderera'
format = "formata"
def render(self, obj=None, media_type=None):
return RENDERER_A_SERIALIZER(obj)
class RendererB(BaseRenderer):
media_type = 'mock/rendererb'
format = "formatb"
def render(self, obj=None, media_type=None):
return RENDERER_B_SERIALIZER(obj)
class MockView(ResponseMixin, DjangoView):
renderers = (RendererA, RendererB)
def get(self, request, **kwargs):
response = Response(DUMMYSTATUS, DUMMYCONTENT)
return self.render(response)
class MockGETView(View):
def get(self, request, **kwargs):
return {'foo': ['bar', 'baz']}
class HTMLView(View):
renderers = (DocumentingHTMLRenderer, )
def get(self, request, **kwargs):
return 'text'
class HTMLView1(View):
renderers = (DocumentingHTMLRenderer, JSONRenderer)
def get(self, request, **kwargs):
return 'text'
urlpatterns = patterns('',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderers=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderers=[RendererA, RendererB])),
url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderers=[JSONRenderer, JSONPRenderer])),
url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderers=[JSONPRenderer])),
url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()),
)
class RendererIntegrationTests(TestCase):
"""
End-to-end testing of renderers using an RendererMixin on a generic view.
"""
urls = 'djangorestframework.tests.renderers'
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
self.assertEquals(resp['Content-Type'], RendererA.media_type)
self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
self.assertEquals(resp.status_code, DUMMYSTATUS)
self.assertEquals(resp['Content-Type'], RendererA.media_type)
self.assertEquals(resp.content, '')
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEquals(resp['Content-Type'], RendererA.media_type)
self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEquals(resp['Content-Type'], RendererA.media_type)
self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_accept_query(self):
"""The '_accept' query string should behave in the same way as the Accept header."""
resp = self.client.get('/?_accept=%s' % RendererB.media_type)
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format)
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format,
HTTP_ACCEPT=RendererB.media_type)
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
def test_conflicting_format_query_and_accept_ignores_accept(self):
"""If a 'format' query is specified that does not match the Accept
header, we should only honor the 'format' query string."""
resp = self.client.get('/?format=%s' % RendererB.format,
HTTP_ACCEPT='dummy')
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
def test_bla(self):
resp = self.client.get('/?format=formatb',
HTTP_ACCEPT='text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8')
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
_flat_repr = '{"foo": ["bar", "baz"]}'
_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
def strip_trailing_whitespace(content):
"""
Seems to be some inconsistencies re. trailing whitespace with
different versions of the json lib.
"""
return re.sub(' +\n', '\n', content)
class JSONRendererTests(TestCase):
"""
Tests specific to the JSON Renderer
"""
def test_without_content_type_args(self):
"""
Test basic JSON rendering.
"""
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer(None)
content = renderer.render(obj, 'application/json')
# Fix failing test case which depends on version of JSON library.
self.assertEquals(content, _flat_repr)
def test_with_content_type_args(self):
"""
Test JSON rendering with additional content type arguments supplied.
"""
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer(None)
content = renderer.render(obj, 'application/json; indent=2')
self.assertEquals(strip_trailing_whitespace(content), _indented_repr)
def test_render_and_parse(self):
"""
Test rendering and then parsing returns the original object.
IE obj -> render -> parse -> obj.
"""
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer(None)
parser = JSONParser(None)
content = renderer.render(obj, 'application/json')
(data, files) = parser.parse(StringIO(content))
self.assertEquals(obj, data)
class JSONPRendererTests(TestCase):
"""
Tests specific to the JSONP Renderer
"""
urls = 'djangorestframework.tests.renderers'
def test_without_callback_with_json_renderer(self):
"""
Test JSONP rendering with View JSON Renderer.
"""
resp = self.client.get('/jsonp/jsonrenderer',
HTTP_ACCEPT='application/json-p')
self.assertEquals(resp.status_code, 200)
self.assertEquals(resp['Content-Type'], 'application/json-p')
self.assertEquals(resp.content, 'callback(%s);' % _flat_repr)
def test_without_callback_without_json_renderer(self):
"""
Test JSONP rendering without View JSON Renderer.
"""
resp = self.client.get('/jsonp/nojsonrenderer',
HTTP_ACCEPT='application/json-p')
self.assertEquals(resp.status_code, 200)
self.assertEquals(resp['Content-Type'], 'application/json-p')
self.assertEquals(resp.content, 'callback(%s);' % _flat_repr)
def test_with_callback(self):
"""
Test JSONP rendering with callback function name.
"""
callback_func = 'myjsonpcallback'
resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/json-p')
self.assertEquals(resp.status_code, 200)
self.assertEquals(resp['Content-Type'], 'application/json-p')
self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr))
if YAMLRenderer:
_yaml_repr = 'foo: [bar, baz]\n'
class YAMLRendererTests(TestCase):
"""
Tests specific to the JSON Renderer
"""
def test_render(self):
"""
Test basic YAML rendering.
"""
obj = {'foo': ['bar', 'baz']}
renderer = YAMLRenderer(None)
content = renderer.render(obj, 'application/yaml')
self.assertEquals(content, _yaml_repr)
def test_render_and_parse(self):
"""
Test rendering and then parsing returns the original object.
IE obj -> render -> parse -> obj.
"""
obj = {'foo': ['bar', 'baz']}
renderer = YAMLRenderer(None)
parser = YAMLParser(None)
content = renderer.render(obj, 'application/yaml')
(data, files) = parser.parse(StringIO(content))
self.assertEquals(obj, data)
class XMLRendererTestCase(TestCase):
"""
Tests specific to the XML Renderer
"""
_complex_data = {
"creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
"name": "name",
"sub_data_list": [
{
"sub_id": 1,
"sub_name": "first"
},
{
"sub_id": 2,
"sub_name": "second"
}
]
}
def test_render_string(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 'astring'}, 'application/xml')
self.assertXMLContains(content, '<field>astring</field>')
def test_render_integer(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 111}, 'application/xml')
self.assertXMLContains(content, '<field>111</field>')
def test_render_datetime(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({
'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
}, 'application/xml')
self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
def test_render_float(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': 123.4}, 'application/xml')
self.assertXMLContains(content, '<field>123.4</field>')
def test_render_decimal(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
self.assertXMLContains(content, '<field>111.2</field>')
def test_render_none(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render({'field': None}, 'application/xml')
self.assertXMLContains(content, '<field></field>')
def test_render_complex_data(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = renderer.render(self._complex_data, 'application/xml')
self.assertXMLContains(content, '<sub_name>first</sub_name>')
self.assertXMLContains(content, '<sub_name>second</sub_name>')
def test_render_and_parse_complex_data(self):
"""
Test XML rendering.
"""
renderer = XMLRenderer(None)
content = StringIO(renderer.render(self._complex_data, 'application/xml'))
parser = XMLParser(None)
complex_data_out, dummy = parser.parse(content)
error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
self.assertEqual(self._complex_data, complex_data_out, error_msg)
def assertXMLContains(self, xml, string):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>'))
self.assertTrue(string in xml, '%r not in %r' % (string, xml))
class Issue122Tests(TestCase):
"""
Tests that covers #122.
"""
urls = 'djangorestframework.tests.renderers'
def test_only_html_renderer(self):
"""
Test if no infinite recursion occurs.
"""
resp = self.client.get('/html')
def test_html_renderer_is_first(self):
"""
Test if no infinite recursion occurs.
"""
resp = self.client.get('/html1')

View File

@@ -0,0 +1,19 @@
# Right now we expect this test to fail - I'm just going to leave it commented out.
# Looking forward to actually being able to raise ExpectedFailure sometime!
#
#from django.test import TestCase
#from djangorestframework.response import Response
#
#
#class TestResponse(TestCase):
#
# # Interface tests
#
# # This is mainly to remind myself that the Response interface needs to change slightly
# def test_response_interface(self):
# """Ensure the Response interface is as expected."""
# response = Response()
# getattr(response, 'status')
# getattr(response, 'content')
# getattr(response, 'headers')

View File

@@ -0,0 +1,28 @@
from django.conf.urls.defaults import patterns, url
from django.core.urlresolvers import reverse
from django.test import TestCase
from django.utils import simplejson as json
from djangorestframework.views import View
class MockView(View):
"""Mock resource which simply returns a URL, so that we can ensure that reversed URLs are fully qualified"""
permissions = ()
def get(self, request):
return reverse('another')
urlpatterns = patterns('',
url(r'^$', MockView.as_view()),
url(r'^another$', MockView.as_view(), name='another'),
)
class ReverseTests(TestCase):
"""Tests for """
urls = 'djangorestframework.tests.reverse'
def test_reversed_urls_are_fully_qualified(self):
response = self.client.get('/')
self.assertEqual(json.loads(response.content), 'http://testserver/another')

View File

@@ -0,0 +1,139 @@
"""Tests for the resource module"""
from django.db import models
from django.test import TestCase
from django.utils.translation import ugettext_lazy
from djangorestframework.serializer import Serializer
import datetime
import decimal
class TestObjectToData(TestCase):
"""
Tests for the Serializer class.
"""
def setUp(self):
self.serializer = Serializer()
self.serialize = self.serializer.serialize
def test_decimal(self):
"""Decimals need to be converted to a string representation."""
self.assertEquals(self.serialize(decimal.Decimal('1.5')), decimal.Decimal('1.5'))
def test_function(self):
"""Functions with no arguments should be called."""
def foo():
return 1
self.assertEquals(self.serialize(foo), 1)
def test_method(self):
"""Methods with only a ``self`` argument should be called."""
class Foo(object):
def foo(self):
return 1
self.assertEquals(self.serialize(Foo().foo), 1)
def test_datetime(self):
"""datetime objects are left as-is."""
now = datetime.datetime.now()
self.assertEquals(self.serialize(now), now)
def test_dict_method_name_collision(self):
"""dict with key that collides with dict method name"""
self.assertEquals(self.serialize({'items': 'foo'}), {'items': u'foo'})
self.assertEquals(self.serialize({'keys': 'foo'}), {'keys': u'foo'})
self.assertEquals(self.serialize({'values': 'foo'}), {'values': u'foo'})
def test_ugettext_lazy(self):
self.assertEquals(self.serialize(ugettext_lazy('foobar')), u'foobar')
class TestFieldNesting(TestCase):
"""
Test nesting the fields in the Serializer class
"""
def setUp(self):
self.serializer = Serializer()
self.serialize = self.serializer.serialize
class M1(models.Model):
field1 = models.CharField(max_length=256)
field2 = models.CharField(max_length=256)
class M2(models.Model):
field = models.OneToOneField(M1)
class M3(models.Model):
field = models.ForeignKey(M1)
self.m1 = M1(field1='foo', field2='bar')
self.m2 = M2(field=self.m1)
self.m3 = M3(field=self.m1)
def test_tuple_nesting(self):
"""
Test tuple nesting on `fields` attr
"""
class SerializerM2(Serializer):
fields = (('field', ('field1',)),)
class SerializerM3(Serializer):
fields = (('field', ('field2',)),)
self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}})
self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}})
def test_serializer_class_nesting(self):
"""
Test related model serialization
"""
class NestedM2(Serializer):
fields = ('field1', )
class NestedM3(Serializer):
fields = ('field2', )
class SerializerM2(Serializer):
fields = [('field', NestedM2)]
class SerializerM3(Serializer):
fields = [('field', NestedM3)]
self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}})
self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}})
def test_serializer_classname_nesting(self):
"""
Test related model serialization
"""
class SerializerM2(Serializer):
fields = [('field', 'NestedM2')]
class SerializerM3(Serializer):
fields = [('field', 'NestedM3')]
class NestedM2(Serializer):
fields = ('field1', )
class NestedM3(Serializer):
fields = ('field2', )
self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}})
self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}})
def test_serializer_overridden_hook_method(self):
"""
Test serializing a model instance which overrides a class method on the
serializer. Checks for correct behaviour in odd edge case.
"""
class SerializerM2(Serializer):
fields = ('overridden', )
def overridden(self):
return False
self.m2.overridden = True
self.assertEqual(SerializerM2().serialize_model(self.m2),
{'overridden': True})

View File

@@ -0,0 +1,12 @@
"""Tests for the status module"""
from django.test import TestCase
from djangorestframework import status
class TestStatus(TestCase):
"""Simple sanity test to check the status module"""
def test_status(self):
"""Ensure the status module is present and correct."""
self.assertEquals(200, status.HTTP_200_OK)
self.assertEquals(404, status.HTTP_404_NOT_FOUND)

View File

@@ -0,0 +1,63 @@
# http://djangosnippets.org/snippets/1011/
from django.conf import settings
from django.core.management import call_command
from django.db.models import loading
from django.test import TestCase
NO_SETTING = ('!', None)
class TestSettingsManager(object):
"""
A class which can modify some Django settings temporarily for a
test and then revert them to their original values later.
Automatically handles resyncing the DB if INSTALLED_APPS is
modified.
"""
def __init__(self):
self._original_settings = {}
def set(self, **kwargs):
for k,v in kwargs.iteritems():
self._original_settings.setdefault(k, getattr(settings, k,
NO_SETTING))
setattr(settings, k, v)
if 'INSTALLED_APPS' in kwargs:
self.syncdb()
def syncdb(self):
loading.cache.loaded = False
call_command('syncdb', verbosity=0)
def revert(self):
for k,v in self._original_settings.iteritems():
if v == NO_SETTING:
delattr(settings, k)
else:
setattr(settings, k, v)
if 'INSTALLED_APPS' in self._original_settings:
self.syncdb()
self._original_settings = {}
class SettingsTestCase(TestCase):
"""
A subclass of the Django TestCase with a settings_manager
attribute which is an instance of TestSettingsManager.
Comes with a tearDown() method that calls
self.settings_manager.revert().
"""
def __init__(self, *args, **kwargs):
super(SettingsTestCase, self).__init__(*args, **kwargs)
self.settings_manager = TestSettingsManager()
def tearDown(self):
self.settings_manager.revert()
class TestModelsTestCase(SettingsTestCase):
def setUp(self, *args, **kwargs):
installed_apps = tuple(settings.INSTALLED_APPS) + ('djangorestframework.tests',)
self.settings_manager.set(INSTALLED_APPS=installed_apps)

View File

@@ -0,0 +1,148 @@
"""
Tests for the throttling implementations in the permissions module.
"""
from django.test import TestCase
from django.contrib.auth.models import User
from django.core.cache import cache
from djangorestframework.compat import RequestFactory
from djangorestframework.views import View
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling
from djangorestframework.resources import FormResource
class MockView(View):
permissions = ( PerUserThrottling, )
throttle = '3/sec'
def get(self, request):
return 'foo'
class MockView_PerViewThrottling(MockView):
permissions = ( PerViewThrottling, )
class MockView_PerResourceThrottling(MockView):
permissions = ( PerResourceThrottling, )
resource = FormResource
class MockView_MinuteThrottling(MockView):
throttle = '3/min'
class ThrottlingTests(TestCase):
urls = 'djangorestframework.tests.throttling'
def setUp(self):
"""
Reset the cache so that no throttles will be active
"""
cache.clear()
self.factory = RequestFactory()
def test_requests_are_throttled(self):
"""
Ensure request rate is limited
"""
request = self.factory.get('/')
for dummy in range(4):
response = MockView.as_view()(request)
self.assertEqual(503, response.status_code)
def set_throttle_timer(self, view, value):
"""
Explicitly set the timer, overriding time.time()
"""
view.permissions[0].timer = lambda self: value
def test_request_throttling_expires(self):
"""
Ensure request rate is limited for a limited duration only
"""
self.set_throttle_timer(MockView, 0)
request = self.factory.get('/')
for dummy in range(4):
response = MockView.as_view()(request)
self.assertEqual(503, response.status_code)
# Advance the timer by one second
self.set_throttle_timer(MockView, 1)
response = MockView.as_view()(request)
self.assertEqual(200, response.status_code)
def ensure_is_throttled(self, view, expect):
request = self.factory.get('/')
request.user = User.objects.create(username='a')
for dummy in range(3):
view.as_view()(request)
request.user = User.objects.create(username='b')
response = view.as_view()(request)
self.assertEqual(expect, response.status_code)
def test_request_throttling_is_per_user(self):
"""
Ensure request rate is only limited per user, not globally for
PerUserThrottles
"""
self.ensure_is_throttled(MockView, 200)
def test_request_throttling_is_per_view(self):
"""
Ensure request rate is limited globally per View for PerViewThrottles
"""
self.ensure_is_throttled(MockView_PerViewThrottling, 503)
def test_request_throttling_is_per_resource(self):
"""
Ensure request rate is limited globally per Resource for PerResourceThrottles
"""
self.ensure_is_throttled(MockView_PerResourceThrottling, 503)
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
"""
Ensure the response returns an X-Throttle field with status and next attributes
set properly.
"""
request = self.factory.get('/')
for timer, expect in expected_headers:
self.set_throttle_timer(view, timer)
response = view.as_view()(request)
self.assertEquals(response['X-Throttle'], expect)
def test_seconds_fields(self):
"""
Ensure for second based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView,
((0, 'status=SUCCESS; next=0.33 sec'),
(0, 'status=SUCCESS; next=0.50 sec'),
(0, 'status=SUCCESS; next=1.00 sec'),
(0, 'status=FAILURE; next=1.00 sec')
))
def test_minutes_fields(self):
"""
Ensure for minute based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, 'status=SUCCESS; next=20.00 sec'),
(0, 'status=SUCCESS; next=30.00 sec'),
(0, 'status=SUCCESS; next=60.00 sec'),
(0, 'status=FAILURE; next=60.00 sec')
))
def test_next_rate_remains_constant_if_followed(self):
"""
If a client follows the recommended next request rate,
the throttling rate should stay constant.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, 'status=SUCCESS; next=20.00 sec'),
(20, 'status=SUCCESS; next=20.00 sec'),
(40, 'status=SUCCESS; next=20.00 sec'),
(60, 'status=SUCCESS; next=20.00 sec'),
(80, 'status=SUCCESS; next=20.00 sec')
))

View File

@@ -0,0 +1,325 @@
from django import forms
from django.db import models
from django.test import TestCase
from djangorestframework.resources import FormResource, ModelResource
from djangorestframework.response import ErrorResponse
from djangorestframework.views import View
class TestDisabledValidations(TestCase):
"""Tests on FormValidator with validation disabled by setting form to None"""
def test_disabled_form_validator_returns_content_unchanged(self):
"""If the view's form attribute is None then FormValidator(view).validate_request(content, None)
should just return the content unmodified."""
class DisabledFormResource(FormResource):
form = None
class MockView(View):
resource = DisabledFormResource
view = MockView()
content = {'qwerty': 'uiop'}
self.assertEqual(FormResource(view).validate_request(content, None), content)
def test_disabled_form_validator_get_bound_form_returns_none(self):
"""If the view's form attribute is None on then
FormValidator(view).get_bound_form(content) should just return None."""
class DisabledFormResource(FormResource):
form = None
class MockView(View):
resource = DisabledFormResource
view = MockView()
content = {'qwerty': 'uiop'}
self.assertEqual(FormResource(view).get_bound_form(content), None)
def test_disabled_model_form_validator_returns_content_unchanged(self):
"""If the view's form is None and does not have a Resource with a model set then
ModelFormValidator(view).validate_request(content, None) should just return the content unmodified."""
class DisabledModelFormView(View):
resource = ModelResource
view = DisabledModelFormView()
content = {'qwerty': 'uiop'}
self.assertEqual(ModelResource(view).get_bound_form(content), None)
def test_disabled_model_form_validator_get_bound_form_returns_none(self):
"""If the form attribute is None on FormValidatorMixin then get_bound_form(content) should just return None."""
class DisabledModelFormView(View):
resource = ModelResource
view = DisabledModelFormView()
content = {'qwerty': 'uiop'}
self.assertEqual(ModelResource(view).get_bound_form(content), None)
class TestNonFieldErrors(TestCase):
"""Tests against form validation errors caused by non-field errors. (eg as might be caused by some custom form validation)"""
def test_validate_failed_due_to_non_field_error_returns_appropriate_message(self):
"""If validation fails with a non-field error, ensure the response a non-field error"""
class MockForm(forms.Form):
field1 = forms.CharField(required=False)
field2 = forms.CharField(required=False)
ERROR_TEXT = 'You may not supply both field1 and field2'
def clean(self):
if 'field1' in self.cleaned_data and 'field2' in self.cleaned_data:
raise forms.ValidationError(self.ERROR_TEXT)
return self.cleaned_data
class MockResource(FormResource):
form = MockForm
class MockView(View):
pass
view = MockView()
content = {'field1': 'example1', 'field2': 'example2'}
try:
MockResource(view).validate_request(content, None)
except ErrorResponse, exc:
self.assertEqual(exc.response.raw_content, {'errors': [MockForm.ERROR_TEXT]})
else:
self.fail('ErrorResponse was not raised')
class TestFormValidation(TestCase):
"""Tests which check basic form validation.
Also includes the same set of tests with a ModelFormValidator for which the form has been explicitly set.
(ModelFormValidator should behave as FormValidator if a form is set rather than relying on the default ModelForm)"""
def setUp(self):
class MockForm(forms.Form):
qwerty = forms.CharField(required=True)
class MockFormResource(FormResource):
form = MockForm
class MockModelResource(ModelResource):
form = MockForm
class MockFormView(View):
resource = MockFormResource
class MockModelFormView(View):
resource = MockModelResource
self.MockFormResource = MockFormResource
self.MockModelResource = MockModelResource
self.MockFormView = MockFormView
self.MockModelFormView = MockModelFormView
def validation_returns_content_unchanged_if_already_valid_and_clean(self, validator):
"""If the content is already valid and clean then validate(content) should just return the content unmodified."""
content = {'qwerty': 'uiop'}
self.assertEqual(validator.validate_request(content, None), content)
def validation_failure_raises_response_exception(self, validator):
"""If form validation fails a ResourceException 400 (Bad Request) should be raised."""
content = {}
self.assertRaises(ErrorResponse, validator.validate_request, content, None)
def validation_does_not_allow_extra_fields_by_default(self, validator):
"""If some (otherwise valid) content includes fields that are not in the form then validation should fail.
It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
broken clients more easily (eg submitting content with a misnamed field)"""
content = {'qwerty': 'uiop', 'extra': 'extra'}
self.assertRaises(ErrorResponse, validator.validate_request, content, None)
def validation_allows_extra_fields_if_explicitly_set(self, validator):
"""If we include an allowed_extra_fields paramater on _validate, then allow fields with those names."""
content = {'qwerty': 'uiop', 'extra': 'extra'}
validator._validate(content, None, allowed_extra_fields=('extra',))
def validation_allows_unknown_fields_if_explicitly_allowed(self, validator):
"""If we set ``unknown_form_fields`` on the form resource, then don't
raise errors on unexpected request data"""
content = {'qwerty': 'uiop', 'extra': 'extra'}
validator.allow_unknown_form_fields = True
self.assertEqual({'qwerty': u'uiop'},
validator.validate_request(content, None),
"Resource didn't accept unknown fields.")
validator.allow_unknown_form_fields = False
def validation_does_not_require_extra_fields_if_explicitly_set(self, validator):
"""If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names."""
content = {'qwerty': 'uiop'}
self.assertEqual(validator._validate(content, None, allowed_extra_fields=('extra',)), content)
def validation_failed_due_to_no_content_returns_appropriate_message(self, validator):
"""If validation fails due to no content, ensure the response contains a single non-field error"""
content = {}
try:
validator.validate_request(content, None)
except ErrorResponse, exc:
self.assertEqual(exc.response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}})
else:
self.fail('ResourceException was not raised')
def validation_failed_due_to_field_error_returns_appropriate_message(self, validator):
"""If validation fails due to a field error, ensure the response contains a single field error"""
content = {'qwerty': ''}
try:
validator.validate_request(content, None)
except ErrorResponse, exc:
self.assertEqual(exc.response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}})
else:
self.fail('ResourceException was not raised')
def validation_failed_due_to_invalid_field_returns_appropriate_message(self, validator):
"""If validation fails due to an invalid field, ensure the response contains a single field error"""
content = {'qwerty': 'uiop', 'extra': 'extra'}
try:
validator.validate_request(content, None)
except ErrorResponse, exc:
self.assertEqual(exc.response.raw_content, {'field_errors': {'extra': ['This field does not exist.']}})
else:
self.fail('ResourceException was not raised')
def validation_failed_due_to_multiple_errors_returns_appropriate_message(self, validator):
"""If validation for multiple reasons, ensure the response contains each error"""
content = {'qwerty': '', 'extra': 'extra'}
try:
validator.validate_request(content, None)
except ErrorResponse, exc:
self.assertEqual(exc.response.raw_content, {'field_errors': {'qwerty': ['This field is required.'],
'extra': ['This field does not exist.']}})
else:
self.fail('ResourceException was not raised')
# Tests on FormResource
def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self):
validator = self.MockFormResource(self.MockFormView())
self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
def test_form_validation_failure_raises_response_exception(self):
validator = self.MockFormResource(self.MockFormView())
self.validation_failure_raises_response_exception(validator)
def test_validation_does_not_allow_extra_fields_by_default(self):
validator = self.MockFormResource(self.MockFormView())
self.validation_does_not_allow_extra_fields_by_default(validator)
def test_validation_allows_extra_fields_if_explicitly_set(self):
validator = self.MockFormResource(self.MockFormView())
self.validation_allows_extra_fields_if_explicitly_set(validator)
def test_validation_allows_unknown_fields_if_explicitly_allowed(self):
validator = self.MockFormResource(self.MockFormView())
self.validation_allows_unknown_fields_if_explicitly_allowed(validator)
def test_validation_does_not_require_extra_fields_if_explicitly_set(self):
validator = self.MockFormResource(self.MockFormView())
self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
def test_validation_failed_due_to_no_content_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView())
self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
def test_validation_failed_due_to_field_error_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView())
self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView())
self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView())
self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
# Same tests on ModelResource
def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self):
validator = self.MockModelResource(self.MockModelFormView())
self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
def test_modelform_validation_failure_raises_response_exception(self):
validator = self.MockModelResource(self.MockModelFormView())
self.validation_failure_raises_response_exception(validator)
def test_modelform_validation_does_not_allow_extra_fields_by_default(self):
validator = self.MockModelResource(self.MockModelFormView())
self.validation_does_not_allow_extra_fields_by_default(validator)
def test_modelform_validation_allows_extra_fields_if_explicitly_set(self):
validator = self.MockModelResource(self.MockModelFormView())
self.validation_allows_extra_fields_if_explicitly_set(validator)
def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self):
validator = self.MockModelResource(self.MockModelFormView())
self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView())
self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView())
self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView())
self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView())
self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
class TestModelFormValidator(TestCase):
"""Tests specific to ModelFormValidatorMixin"""
def setUp(self):
"""Create a validator for a model with two fields and a property."""
class MockModel(models.Model):
qwerty = models.CharField(max_length=256)
uiop = models.CharField(max_length=256, blank=True)
@property
def readonly(self):
return 'read only'
class MockResource(ModelResource):
model = MockModel
class MockView(View):
resource = MockResource
self.validator = MockResource(MockView)
def test_property_fields_are_allowed_on_model_forms(self):
"""Validation on ModelForms may include property fields that exist on the Model to be included in the input."""
content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only'}
self.assertEqual(self.validator.validate_request(content, None), content)
def test_property_fields_are_not_required_on_model_forms(self):
"""Validation on ModelForms does not require property fields that exist on the Model to be included in the input."""
content = {'qwerty': 'example', 'uiop': 'example'}
self.assertEqual(self.validator.validate_request(content, None), content)
def test_extra_fields_not_allowed_on_model_forms(self):
"""If some (otherwise valid) content includes fields that are not in the form then validation should fail.
It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
broken clients more easily (eg submitting content with a misnamed field)"""
content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only', 'extra': 'extra'}
self.assertRaises(ErrorResponse, self.validator.validate_request, content, None)
def test_validate_requires_fields_on_model_forms(self):
"""If some (otherwise valid) content includes fields that are not in the form then validation should fail.
It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
broken clients more easily (eg submitting content with a misnamed field)"""
content = {'readonly': 'read only'}
self.assertRaises(ErrorResponse, self.validator.validate_request, content, None)
def test_validate_does_not_require_blankable_fields_on_model_forms(self):
"""Test standard ModelForm validation behaviour - fields with blank=True are not required."""
content = {'qwerty': 'example', 'readonly': 'read only'}
self.validator.validate_request(content, None)
def test_model_form_validator_uses_model_forms(self):
self.assertTrue(isinstance(self.validator.get_bound_form(), forms.ModelForm))

View File

@@ -0,0 +1,137 @@
from django.conf.urls.defaults import patterns, url
from django.http import HttpResponse
from django.test import TestCase
from django.test import Client
from django import forms
from django.db import models
from djangorestframework.views import View
from djangorestframework.parsers import JSONParser
from djangorestframework.resources import ModelResource
from djangorestframework.views import ListOrCreateModelView, InstanceModelView
from StringIO import StringIO
class MockView(View):
"""This is a basic mock view"""
pass
class MockViewFinal(View):
"""View with final() override"""
def final(self, request, response, *args, **kwargs):
return HttpResponse('{"test": "passed"}', content_type="application/json")
class ResourceMockView(View):
"""This is a resource-based mock view"""
class MockForm(forms.Form):
foo = forms.BooleanField(required=False)
bar = forms.IntegerField(help_text='Must be an integer.')
baz = forms.CharField(max_length=32)
form = MockForm
class MockResource(ModelResource):
"""This is a mock model-based resource"""
class MockResourceModel(models.Model):
foo = models.BooleanField()
bar = models.IntegerField(help_text='Must be an integer.')
baz = models.CharField(max_length=32, help_text='Free text. Max length 32 chars.')
model = MockResourceModel
fields = ('foo', 'bar', 'baz')
urlpatterns = patterns('djangorestframework.utils.staticviews',
url(r'^accounts/login$', 'api_login'),
url(r'^accounts/logout$', 'api_logout'),
url(r'^mock/$', MockView.as_view()),
url(r'^mock/final/$', MockViewFinal.as_view()),
url(r'^resourcemock/$', ResourceMockView.as_view()),
url(r'^model/$', ListOrCreateModelView.as_view(resource=MockResource)),
url(r'^model/(?P<pk>[^/]+)/$', InstanceModelView.as_view(resource=MockResource)),
)
class BaseViewTests(TestCase):
"""Test the base view class of djangorestframework"""
urls = 'djangorestframework.tests.views'
def test_view_call_final(self):
response = self.client.options('/mock/final/')
self.assertEqual(response['Content-Type'].split(';')[0], "application/json")
parser = JSONParser(None)
(data, files) = parser.parse(StringIO(response.content))
self.assertEqual(data['test'], 'passed')
def test_options_method_simple_view(self):
response = self.client.options('/mock/')
self._verify_options_response(response,
name='Mock',
description='This is a basic mock view')
def test_options_method_resource_view(self):
response = self.client.options('/resourcemock/')
self._verify_options_response(response,
name='Resource Mock',
description='This is a resource-based mock view',
fields={'foo':'BooleanField',
'bar':'IntegerField',
'baz':'CharField',
})
def test_options_method_model_resource_list_view(self):
response = self.client.options('/model/')
self._verify_options_response(response,
name='Mock List',
description='This is a mock model-based resource',
fields={'foo':'BooleanField',
'bar':'IntegerField',
'baz':'CharField',
})
def test_options_method_model_resource_detail_view(self):
response = self.client.options('/model/0/')
self._verify_options_response(response,
name='Mock Instance',
description='This is a mock model-based resource',
fields={'foo':'BooleanField',
'bar':'IntegerField',
'baz':'CharField',
})
def _verify_options_response(self, response, name, description, fields=None, status=200,
mime_type='application/json'):
self.assertEqual(response.status_code, status)
self.assertEqual(response['Content-Type'].split(';')[0], mime_type)
parser = JSONParser(None)
(data, files) = parser.parse(StringIO(response.content))
self.assertTrue('application/json' in data['renders'])
self.assertEqual(name, data['name'])
self.assertEqual(description, data['description'])
if fields is None:
self.assertFalse(hasattr(data, 'fields'))
else:
self.assertEqual(data['fields'], fields)
class ExtraViewsTests(TestCase):
"""Test the extra views djangorestframework provides"""
urls = 'djangorestframework.tests.views'
def test_login_view(self):
"""Ensure the login view exists"""
response = self.client.get('/accounts/login')
self.assertEqual(response.status_code, 200)
self.assertEqual(response['Content-Type'].split(';')[0], 'text/html')
def test_logout_view(self):
"""Ensure the logout view exists"""
response = self.client.get('/accounts/logout')
self.assertEqual(response.status_code, 200)
self.assertEqual(response['Content-Type'].split(';')[0], 'text/html')
# TODO: Add login/logout behaviour tests

View File

@@ -0,0 +1,6 @@
from django.conf.urls.defaults import patterns
urlpatterns = patterns('djangorestframework.utils.staticviews',
(r'^accounts/login/$', 'api_login'),
(r'^accounts/logout/$', 'api_logout'),
)

View File

@@ -0,0 +1,175 @@
from django.utils.encoding import smart_unicode
from django.utils.xmlutils import SimplerXMLGenerator
from django.core.urlresolvers import resolve
from django.conf import settings
from djangorestframework.compat import StringIO
import re
import xml.etree.ElementTree as ET
#def admin_media_prefix(request):
# """Adds the ADMIN_MEDIA_PREFIX to the request context."""
# return {'ADMIN_MEDIA_PREFIX': settings.ADMIN_MEDIA_PREFIX}
from mediatypes import media_type_matches, is_form_media_type
from mediatypes import add_media_type_param, get_media_type_params, order_by_precedence
MSIE_USER_AGENT_REGEX = re.compile(r'^Mozilla/[0-9]+\.[0-9]+ \([^)]*; MSIE [0-9]+\.[0-9]+[a-z]?;[^)]*\)(?!.* Opera )')
def as_tuple(obj):
"""
Given an object which may be a list/tuple, another object, or None,
return that object in list form.
IE:
If the object is already a list/tuple just return it.
If the object is not None, return it in a list with a single element.
If the object is None return an empty list.
"""
if obj is None:
return ()
elif isinstance(obj, list):
return tuple(obj)
elif isinstance(obj, tuple):
return obj
return (obj,)
def url_resolves(url):
"""
Return True if the given URL is mapped to a view in the urlconf, False otherwise.
"""
try:
resolve(url)
except Exception:
return False
return True
# From http://www.koders.com/python/fidB6E125C586A6F49EAC38992CF3AFDAAE35651975.aspx?s=mdef:xml
#class object_dict(dict):
# """object view of dict, you can
# >>> a = object_dict()
# >>> a.fish = 'fish'
# >>> a['fish']
# 'fish'
# >>> a['water'] = 'water'
# >>> a.water
# 'water'
# >>> a.test = {'value': 1}
# >>> a.test2 = object_dict({'name': 'test2', 'value': 2})
# >>> a.test, a.test2.name, a.test2.value
# (1, 'test2', 2)
# """
# def __init__(self, initd=None):
# if initd is None:
# initd = {}
# dict.__init__(self, initd)
#
# def __getattr__(self, item):
# d = self.__getitem__(item)
# # if value is the only key in object, you can omit it
# if isinstance(d, dict) and 'value' in d and len(d) == 1:
# return d['value']
# else:
# return d
#
# def __setattr__(self, item, value):
# self.__setitem__(item, value)
# From xml2dict
class XML2Dict(object):
def __init__(self):
pass
def _parse_node(self, node):
node_tree = {}
# Save attrs and text, hope there will not be a child with same name
if node.text:
node_tree = node.text
for (k,v) in node.attrib.items():
k,v = self._namespace_split(k, v)
node_tree[k] = v
#Save childrens
for child in node.getchildren():
tag, tree = self._namespace_split(child.tag, self._parse_node(child))
if tag not in node_tree: # the first time, so store it in dict
node_tree[tag] = tree
continue
old = node_tree[tag]
if not isinstance(old, list):
node_tree.pop(tag)
node_tree[tag] = [old] # multi times, so change old dict to a list
node_tree[tag].append(tree) # add the new one
return node_tree
def _namespace_split(self, tag, value):
"""
Split the tag '{http://cs.sfsu.edu/csc867/myscheduler}patients'
ns = http://cs.sfsu.edu/csc867/myscheduler
name = patients
"""
result = re.compile("\{(.*)\}(.*)").search(tag)
if result:
value.namespace, tag = result.groups()
return (tag, value)
def parse(self, file):
"""parse a xml file to a dict"""
f = open(file, 'r')
return self.fromstring(f.read())
def fromstring(self, s):
"""parse a string"""
t = ET.fromstring(s)
unused_root_tag, root_tree = self._namespace_split(t.tag, self._parse_node(t))
return root_tree
def xml2dict(input):
return XML2Dict().fromstring(input)
# Piston:
class XMLRenderer():
def _to_xml(self, xml, data):
if isinstance(data, (list, tuple)):
for item in data:
xml.startElement("list-item", {})
self._to_xml(xml, item)
xml.endElement("list-item")
elif isinstance(data, dict):
for key, value in data.iteritems():
xml.startElement(key, {})
self._to_xml(xml, value)
xml.endElement(key)
elif data is None:
# Don't output any value
pass
else:
xml.characters(smart_unicode(data))
def dict2xml(self, data):
stream = StringIO.StringIO()
xml = SimplerXMLGenerator(stream, "utf-8")
xml.startDocument()
xml.startElement("root", {})
self._to_xml(xml, data)
xml.endElement("root")
xml.endDocument()
return stream.getvalue()
def dict2xml(input):
return XMLRenderer().dict2xml(input)

View File

@@ -0,0 +1,32 @@
from django.core.urlresolvers import resolve
def get_breadcrumbs(url):
"""Given a url returns a list of breadcrumbs, which are each a tuple of (name, url)."""
from djangorestframework.views import View
def breadcrumbs_recursive(url, breadcrumbs_list):
"""Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url."""
try:
(view, unused_args, unused_kwargs) = resolve(url)
except Exception:
pass
else:
# Check if this is a REST framework view, and if so add it to the breadcrumbs
if isinstance(getattr(view, 'cls_instance', None), View):
breadcrumbs_list.insert(0, (view.cls_instance.get_name(), url))
if url == '':
# All done
return breadcrumbs_list
elif url.endswith('/'):
# Drop trailing slash off the end and continue to try to resolve more breadcrumbs
return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list)
# Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs
return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list)
return breadcrumbs_recursive(url, [])

View File

@@ -0,0 +1,113 @@
"""
Handling of media types, as found in HTTP Content-Type and Accept headers.
See http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7
"""
from django.http.multipartparser import parse_header
def media_type_matches(lhs, rhs):
"""
Returns ``True`` if the media type in the first argument <= the
media type in the second argument. The media types are strings
as described by the HTTP spec.
Valid media type strings include:
'application/json; indent=4'
'application/json'
'text/*'
'*/*'
"""
lhs = _MediaType(lhs)
rhs = _MediaType(rhs)
return lhs.match(rhs)
def is_form_media_type(media_type):
"""
Return True if the media type is a valid form media type as defined by the HTML4 spec.
(NB. HTML5 also adds text/plain to the list of valid form media types, but we don't support this here)
"""
media_type = _MediaType(media_type)
return media_type.full_type == 'application/x-www-form-urlencoded' or \
media_type.full_type == 'multipart/form-data'
def add_media_type_param(media_type, key, val):
"""
Add a key, value parameter to a media type string, and return the new media type string.
"""
media_type = _MediaType(media_type)
media_type.params[key] = val
return str(media_type)
def get_media_type_params(media_type):
"""
Return a dictionary of the parameters on the given media type.
"""
return _MediaType(media_type).params
def order_by_precedence(media_type_lst):
"""
Returns a list of lists of media type strings, ordered by precedence.
Precedence is determined by how specific a media type is:
3. 'type/subtype; param=val'
2. 'type/subtype'
1. 'type/*'
0. '*/*'
"""
ret = [[], [], [], []]
for media_type in media_type_lst:
precedence = _MediaType(media_type).precedence
ret[3 - precedence].append(media_type)
return ret
class _MediaType(object):
def __init__(self, media_type_str):
if media_type_str is None:
media_type_str = ''
self.orig = media_type_str
self.full_type, self.params = parse_header(media_type_str)
self.main_type, sep, self.sub_type = self.full_type.partition('/')
def match(self, other):
"""Return true if this MediaType satisfies the given MediaType."""
for key in self.params.keys():
if key != 'q' and other.params.get(key, None) != self.params.get(key, None):
return False
if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
return False
if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type:
return False
return True
@property
def precedence(self):
"""
Return a precedence level from 0-3 for the media type given how specific it is.
"""
if self.main_type == '*':
return 0
elif self.sub_type == '*':
return 1
elif not self.params or self.params.keys() == ['q']:
return 2
return 3
def __str__(self):
return unicode(self).encode('utf-8')
def __unicode__(self):
ret = "%s/%s" % (self.main_type, self.sub_type)
for key, val in self.params.items():
ret += "; %s=%s" % (key, val)
return ret

View File

@@ -0,0 +1,61 @@
from django.contrib.auth.views import *
from django.conf import settings
from django.http import HttpResponse
from django.shortcuts import render_to_response
from django.template import RequestContext
import base64
# BLERGH
# Replicate django.contrib.auth.views.login simply so we don't have get users to update TEMPLATE_CONTEXT_PROCESSORS
# to add ADMIN_MEDIA_PREFIX to the RequestContext. I don't like this but really really want users to not have to
# be making settings changes in order to accomodate django-rest-framework
@csrf_protect
@never_cache
def api_login(request, template_name='djangorestframework/login.html',
redirect_field_name=REDIRECT_FIELD_NAME,
authentication_form=AuthenticationForm):
"""Displays the login form and handles the login action."""
redirect_to = request.REQUEST.get(redirect_field_name, '')
if request.method == "POST":
form = authentication_form(data=request.POST)
if form.is_valid():
# Light security check -- make sure redirect_to isn't garbage.
if not redirect_to or ' ' in redirect_to:
redirect_to = settings.LOGIN_REDIRECT_URL
# Heavier security check -- redirects to http://example.com should
# not be allowed, but things like /view/?param=http://example.com
# should be allowed. This regex checks if there is a '//' *before* a
# question mark.
elif '//' in redirect_to and re.match(r'[^\?]*//', redirect_to):
redirect_to = settings.LOGIN_REDIRECT_URL
# Okay, security checks complete. Log the user in.
auth_login(request, form.get_user())
if request.session.test_cookie_worked():
request.session.delete_test_cookie()
return HttpResponseRedirect(redirect_to)
else:
form = authentication_form(request)
request.session.set_test_cookie()
#current_site = get_current_site(request)
return render_to_response(template_name, {
'form': form,
redirect_field_name: redirect_to,
#'site': current_site,
#'site_name': current_site.name,
'ADMIN_MEDIA_PREFIX': settings.ADMIN_MEDIA_PREFIX,
}, context_instance=RequestContext(request))
def api_logout(request, next_page=None, template_name='djangorestframework/login.html', redirect_field_name=REDIRECT_FIELD_NAME):
return logout(request, next_page, template_name, redirect_field_name)

View File

@@ -0,0 +1,298 @@
"""
The :mod:`views` module provides the Views you will most probably
be subclassing in your implementation.
By setting or modifying class attributes on your view, you change it's predefined behaviour.
"""
import re
from django.core.urlresolvers import set_script_prefix, get_script_prefix
from django.http import HttpResponse
from django.utils.html import escape
from django.utils.safestring import mark_safe
from django.views.decorators.csrf import csrf_exempt
from djangorestframework.compat import View as DjangoView, apply_markdown
from djangorestframework.response import Response, ErrorResponse
from djangorestframework.mixins import *
from djangorestframework import resources, renderers, parsers, authentication, permissions, status
__all__ = (
'View',
'ModelView',
'InstanceModelView',
'ListModelView',
'ListOrCreateModelView'
)
def _remove_trailing_string(content, trailing):
"""
Strip trailing component `trailing` from `content` if it exists.
Used when generating names from view/resource classes.
"""
if content.endswith(trailing) and content != trailing:
return content[:-len(trailing)]
return content
def _remove_leading_indent(content):
"""
Remove leading indent from a block of text.
Used when generating descriptions from docstrings.
"""
whitespace_counts = [len(line) - len(line.lstrip(' '))
for line in content.splitlines()[1:] if line.lstrip()]
# unindent the content if needed
if whitespace_counts:
whitespace_pattern = '^' + (' ' * min(whitespace_counts))
return re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
return content
def _camelcase_to_spaces(content):
"""
Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view/resource classes.
"""
camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
return re.sub(camelcase_boundry, ' \\1', content).strip()
_resource_classes = (
None,
resources.Resource,
resources.FormResource,
resources.ModelResource
)
class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
"""
Handles incoming requests and maps them to REST operations.
Performs request deserialization, response serialization, authentication and input validation.
"""
resource = None
"""
The resource to use when validating requests and filtering responses,
or `None` to use default behaviour.
"""
renderers = renderers.DEFAULT_RENDERERS
"""
List of renderers the resource can serialize the response with, ordered by preference.
"""
parsers = parsers.DEFAULT_PARSERS
"""
List of parsers the resource can parse the request with.
"""
authentication = (authentication.UserLoggedInAuthentication,
authentication.BasicAuthentication)
"""
List of all authenticating methods to attempt.
"""
permissions = (permissions.FullAnonAccess,)
"""
List of all permissions that must be checked.
"""
@classmethod
def as_view(cls, **initkwargs):
"""
Override the default :meth:`as_view` to store an instance of the view
as an attribute on the callable function. This allows us to discover
information about the view when we do URL reverse lookups.
"""
view = super(View, cls).as_view(**initkwargs)
view.cls_instance = cls(**initkwargs)
return view
@property
def allowed_methods(self):
"""
Return the list of allowed HTTP methods, uppercased.
"""
return [method.upper() for method in self.http_method_names if hasattr(self, method)]
def get_name(self):
"""
Return the resource or view class name for use as this view's name.
Override to customize.
"""
# If this view has a resource that's been overridden, then use that resource for the name
if getattr(self, 'resource', None) not in _resource_classes:
name = self.resource.__name__
name = _remove_trailing_string(name, 'Resource')
name += getattr(self, '_suffix', '')
# If it's a view class with no resource then grok the name from the class name
else:
name = self.__class__.__name__
name = _remove_trailing_string(name, 'View')
return _camelcase_to_spaces(name)
def get_description(self, html=False):
"""
Return the resource or view docstring for use as this view's description.
Override to customize.
"""
description = None
# If this view has a resource that's been overridden,
# then try to use the resource's docstring
if getattr(self, 'resource', None) not in _resource_classes:
description = self.resource.__doc__
# Otherwise use the view docstring
if not description:
description = self.__doc__ or ''
description = _remove_leading_indent(description)
if html:
return self.markup_description(description)
return description
def markup_description(self, description):
if apply_markdown:
description = apply_markdown(description)
else:
description = escape(description).replace('\n', '<br />')
return mark_safe(description)
def http_method_not_allowed(self, request, *args, **kwargs):
"""
Return an HTTP 405 error if an operation is called which does not have a handler method.
"""
raise ErrorResponse(status.HTTP_405_METHOD_NOT_ALLOWED,
{'detail': 'Method \'%s\' not allowed on this resource.' % self.method})
def initial(self, request, *args, **kargs):
"""
Hook for any code that needs to run prior to anything else.
Required if you want to do things like set `request.upload_handlers` before
the authentication and dispatch handling is run.
"""
# Calls to 'reverse' will not be fully qualified unless we set the
# scheme/host/port here.
self.orig_prefix = get_script_prefix()
if not (self.orig_prefix.startswith('http:') or self.orig_prefix.startswith('https:')):
prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host())
set_script_prefix(prefix + self.orig_prefix)
def final(self, request, response, *args, **kargs):
"""
Hook for any code that needs to run after everything else in the view.
"""
# Restore script_prefix.
set_script_prefix(self.orig_prefix)
# Always add these headers.
response.headers['Allow'] = ', '.join(self.allowed_methods)
# sample to allow caching using Vary http header
response.headers['Vary'] = 'Authenticate, Accept'
# merge with headers possibly set at some point in the view
response.headers.update(self.headers)
return self.render(response)
def add_header(self, field, value):
"""
Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.
"""
self.headers[field] = value
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.
@csrf_exempt
def dispatch(self, request, *args, **kwargs):
self.request = request
self.args = args
self.kwargs = kwargs
self.headers = {}
try:
self.initial(request, *args, **kwargs)
# Authenticate and check request has the relevant permissions
self._check_permissions()
# Get the appropriate handler method
if self.method.lower() in self.http_method_names:
handler = getattr(self, self.method.lower(), self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
response_obj = handler(request, *args, **kwargs)
# Allow return value to be either HttpResponse, Response, or an object, or None
if isinstance(response_obj, HttpResponse):
return response_obj
elif isinstance(response_obj, Response):
response = response_obj
elif response_obj is not None:
response = Response(status.HTTP_200_OK, response_obj)
else:
response = Response(status.HTTP_204_NO_CONTENT)
# Pre-serialize filtering (eg filter complex objects into natively serializable types)
response.cleaned_content = self.filter_response(response.raw_content)
except ErrorResponse, exc:
response = exc.response
return self.final(request, response, *args, **kwargs)
def options(self, request, *args, **kwargs):
response_obj = {
'name': self.get_name(),
'description': self.get_description(),
'renders': self._rendered_media_types,
'parses': self._parsed_media_types,
}
form = self.get_bound_form()
if form is not None:
field_name_types = {}
for name, field in form.fields.iteritems():
field_name_types[name] = field.__class__.__name__
response_obj['fields'] = field_name_types
# Note 'ErrorResponse' is misleading, it's just any response
# that should be rendered and returned immediately, without any
# response filtering.
raise ErrorResponse(status.HTTP_200_OK, response_obj)
class ModelView(View):
"""
A RESTful view that maps to a model in the database.
"""
resource = resources.ModelResource
class InstanceModelView(InstanceMixin, ReadModelMixin, UpdateModelMixin, DeleteModelMixin, ModelView):
"""
A view which provides default operations for read/update/delete against a model instance.
"""
_suffix = 'Instance'
class ListModelView(ListModelMixin, ModelView):
"""
A view which provides default operations for list, against a model in the database.
"""
_suffix = 'List'
class ListOrCreateModelView(ListModelMixin, CreateModelMixin, ModelView):
"""
A view which provides default operations for list and create, against a model in the database.
"""
_suffix = 'List'

View File

@@ -43,6 +43,7 @@ urlpatterns = patterns('',
(r'^share/', include('share.urls')), (r'^share/', include('share.urls')),
(r'^api/', include('api.urls')), (r'^api/', include('api.urls')),
(r'^rest/', include('djangorestframework.urls')),
url(r'^shareadmin/$', share_admin, name='share_admin'), url(r'^shareadmin/$', share_admin, name='share_admin'),
(r'^shareadmin/removeshare/$', repo_remove_share), (r'^shareadmin/removeshare/$', repo_remove_share),
(r'^sharedlink/get/$', get_shared_link), (r'^sharedlink/get/$', get_shared_link),