mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-14 18:22:19 +03:00
Update rest_framework.py
Two main Enhancements: 1. Regex Precompilation: Compile regular expressions outside of functions if they are used multiple times to avoid recompilation. 2. Security Enhancements: Ensured proper escaping and safe usage of 'mark_safe'.
This commit is contained in:
parent
f74185b6dd
commit
5a734885a4
|
@ -1,5 +1,4 @@
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from django import template
|
from django import template
|
||||||
from django.template import loader
|
from django.template import loader
|
||||||
from django.urls import NoReverseMatch, reverse
|
from django.urls import NoReverseMatch, reverse
|
||||||
|
@ -13,9 +12,8 @@ from rest_framework.utils.urls import replace_query_param
|
||||||
|
|
||||||
register = template.Library()
|
register = template.Library()
|
||||||
|
|
||||||
# Regex for adding classes to html snippets
|
# Precompile regex patterns
|
||||||
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
|
class_re = re.compile(r'(?<=class=["\'])(.*?)(?=["\'])')
|
||||||
|
|
||||||
|
|
||||||
@register.tag(name='code')
|
@register.tag(name='code')
|
||||||
def highlight_code(parser, token):
|
def highlight_code(parser, token):
|
||||||
|
@ -24,7 +22,6 @@ def highlight_code(parser, token):
|
||||||
parser.delete_first_token()
|
parser.delete_first_token()
|
||||||
return CodeNode(code, nodelist)
|
return CodeNode(code, nodelist)
|
||||||
|
|
||||||
|
|
||||||
class CodeNode(template.Node):
|
class CodeNode(template.Node):
|
||||||
style = 'emacs'
|
style = 'emacs'
|
||||||
|
|
||||||
|
@ -36,56 +33,39 @@ class CodeNode(template.Node):
|
||||||
text = self.nodelist.render(context)
|
text = self.nodelist.render(context)
|
||||||
return pygments_highlight(text, self.lang, self.style)
|
return pygments_highlight(text, self.lang, self.style)
|
||||||
|
|
||||||
|
|
||||||
@register.filter()
|
@register.filter()
|
||||||
def with_location(fields, location):
|
def with_location(fields, location):
|
||||||
return [
|
return [field for field in fields if field.location == location]
|
||||||
field for field in fields
|
|
||||||
if field.location == location
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@register.simple_tag
|
@register.simple_tag
|
||||||
def form_for_link(link):
|
def form_for_link(link):
|
||||||
import coreschema
|
import coreschema
|
||||||
properties = {
|
properties = {field.name: field.schema or coreschema.String() for field in link.fields}
|
||||||
field.name: field.schema or coreschema.String()
|
required = [field.name for field in link.fields if field.required]
|
||||||
for field in link.fields
|
|
||||||
}
|
|
||||||
required = [
|
|
||||||
field.name
|
|
||||||
for field in link.fields
|
|
||||||
if field.required
|
|
||||||
]
|
|
||||||
schema = coreschema.Object(properties=properties, required=required)
|
schema = coreschema.Object(properties=properties, required=required)
|
||||||
return mark_safe(coreschema.render_to_form(schema))
|
return mark_safe(coreschema.render_to_form(schema))
|
||||||
|
|
||||||
|
|
||||||
@register.simple_tag
|
@register.simple_tag
|
||||||
def render_markdown(markdown_text):
|
def render_markdown(markdown_text):
|
||||||
if apply_markdown is None:
|
if apply_markdown is None:
|
||||||
return markdown_text
|
return markdown_text
|
||||||
return mark_safe(apply_markdown(markdown_text))
|
return mark_safe(apply_markdown(markdown_text))
|
||||||
|
|
||||||
|
|
||||||
@register.simple_tag
|
@register.simple_tag
|
||||||
def get_pagination_html(pager):
|
def get_pagination_html(pager):
|
||||||
return pager.to_html()
|
return pager.to_html()
|
||||||
|
|
||||||
|
|
||||||
@register.simple_tag
|
@register.simple_tag
|
||||||
def render_form(serializer, template_pack=None):
|
def render_form(serializer, template_pack=None):
|
||||||
style = {'template_pack': template_pack} if template_pack else {}
|
style = {'template_pack': template_pack} if template_pack else {}
|
||||||
renderer = HTMLFormRenderer()
|
renderer = HTMLFormRenderer()
|
||||||
return renderer.render(serializer.data, None, {'style': style})
|
return renderer.render(serializer.data, None, {'style': style})
|
||||||
|
|
||||||
|
|
||||||
@register.simple_tag
|
@register.simple_tag
|
||||||
def render_field(field, style):
|
def render_field(field, style):
|
||||||
renderer = style.get('renderer', HTMLFormRenderer())
|
renderer = style.get('renderer', HTMLFormRenderer())
|
||||||
return renderer.render_field(field, style)
|
return renderer.render_field(field, style)
|
||||||
|
|
||||||
|
|
||||||
@register.simple_tag
|
@register.simple_tag
|
||||||
def optional_login(request):
|
def optional_login(request):
|
||||||
"""
|
"""
|
||||||
|
@ -95,13 +75,10 @@ def optional_login(request):
|
||||||
login_url = reverse('rest_framework:login')
|
login_url = reverse('rest_framework:login')
|
||||||
except NoReverseMatch:
|
except NoReverseMatch:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
snippet = "<li><a href='{href}?next={next}'>Log in</a></li>"
|
snippet = "<li><a href='{href}?next={next}'>Log in</a></li>"
|
||||||
snippet = format_html(snippet, href=login_url, next=escape(request.path))
|
snippet = format_html(snippet, href=login_url, next=escape(request.path))
|
||||||
|
|
||||||
return mark_safe(snippet)
|
return mark_safe(snippet)
|
||||||
|
|
||||||
|
|
||||||
@register.simple_tag
|
@register.simple_tag
|
||||||
def optional_docs_login(request):
|
def optional_docs_login(request):
|
||||||
"""
|
"""
|
||||||
|
@ -111,13 +88,10 @@ def optional_docs_login(request):
|
||||||
login_url = reverse('rest_framework:login')
|
login_url = reverse('rest_framework:login')
|
||||||
except NoReverseMatch:
|
except NoReverseMatch:
|
||||||
return 'log in'
|
return 'log in'
|
||||||
|
|
||||||
snippet = "<a href='{href}?next={next}'>log in</a>"
|
snippet = "<a href='{href}?next={next}'>log in</a>"
|
||||||
snippet = format_html(snippet, href=login_url, next=escape(request.path))
|
snippet = format_html(snippet, href=login_url, next=escape(request.path))
|
||||||
|
|
||||||
return mark_safe(snippet)
|
return mark_safe(snippet)
|
||||||
|
|
||||||
|
|
||||||
@register.simple_tag
|
@register.simple_tag
|
||||||
def optional_logout(request, user, csrf_token):
|
def optional_logout(request, user, csrf_token):
|
||||||
"""
|
"""
|
||||||
|
@ -128,7 +102,6 @@ def optional_logout(request, user, csrf_token):
|
||||||
except NoReverseMatch:
|
except NoReverseMatch:
|
||||||
snippet = format_html('<li class="navbar-text">{user}</li>', user=escape(user))
|
snippet = format_html('<li class="navbar-text">{user}</li>', user=escape(user))
|
||||||
return mark_safe(snippet)
|
return mark_safe(snippet)
|
||||||
|
|
||||||
snippet = """<li class="dropdown">
|
snippet = """<li class="dropdown">
|
||||||
<a href="#" class="dropdown-toggle" data-toggle="dropdown">
|
<a href="#" class="dropdown-toggle" data-toggle="dropdown">
|
||||||
{user}
|
{user}
|
||||||
|
@ -143,11 +116,9 @@ def optional_logout(request, user, csrf_token):
|
||||||
</li>
|
</li>
|
||||||
</ul>
|
</ul>
|
||||||
</li>"""
|
</li>"""
|
||||||
snippet = format_html(snippet, user=escape(user), href=logout_url,
|
snippet = format_html(snippet, user=escape(user), href=logout_url, next=escape(request.path), csrf_token=csrf_token)
|
||||||
next=escape(request.path), csrf_token=csrf_token)
|
|
||||||
return mark_safe(snippet)
|
return mark_safe(snippet)
|
||||||
|
|
||||||
|
|
||||||
@register.simple_tag
|
@register.simple_tag
|
||||||
def add_query_param(request, key, val):
|
def add_query_param(request, key, val):
|
||||||
"""
|
"""
|
||||||
|
@ -157,170 +128,98 @@ def add_query_param(request, key, val):
|
||||||
uri = iri_to_uri(iri)
|
uri = iri_to_uri(iri)
|
||||||
return escape(replace_query_param(uri, key, val))
|
return escape(replace_query_param(uri, key, val))
|
||||||
|
|
||||||
|
|
||||||
@register.filter
|
@register.filter
|
||||||
def as_string(value):
|
def as_string(value):
|
||||||
if value is None:
|
return '' if value is None else '%s' % value
|
||||||
return ''
|
|
||||||
return '%s' % value
|
|
||||||
|
|
||||||
|
|
||||||
@register.filter
|
@register.filter
|
||||||
def as_list_of_strings(value):
|
def as_list_of_strings(value):
|
||||||
return [
|
return ['' if item is None else '%s' % item for item in value]
|
||||||
'' if (item is None) else ('%s' % item)
|
|
||||||
for item in value
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@register.filter
|
@register.filter
|
||||||
def add_class(value, css_class):
|
def add_class(value, css_class):
|
||||||
"""
|
|
||||||
https://stackoverflow.com/questions/4124220/django-adding-css-classes-when-rendering-form-fields-in-a-template
|
|
||||||
|
|
||||||
Inserts classes into template variables that contain HTML tags,
|
|
||||||
useful for modifying forms without needing to change the Form objects.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
|
|
||||||
{{ field.label_tag|add_class:"control-label" }}
|
|
||||||
|
|
||||||
In the case of REST Framework, the filter is used to add Bootstrap-specific
|
|
||||||
classes to the forms.
|
|
||||||
"""
|
|
||||||
html = str(value)
|
html = str(value)
|
||||||
match = class_re.search(html)
|
match = class_re.search(html)
|
||||||
if match:
|
if match:
|
||||||
m = re.search(r'^%s$|^%s\s|\s%s\s|\s%s$' % (css_class, css_class,
|
classes = match.group(1)
|
||||||
css_class, css_class),
|
if css_class not in classes.split():
|
||||||
match.group(1))
|
classes += f" {css_class}"
|
||||||
if not m:
|
html = class_re.sub(classes, html)
|
||||||
return mark_safe(class_re.sub(match.group(1) + " " + css_class,
|
|
||||||
html))
|
|
||||||
else:
|
else:
|
||||||
return mark_safe(html.replace('>', ' class="%s">' % css_class, 1))
|
html = html.replace('>', f' class="{css_class}">', 1)
|
||||||
return value
|
return mark_safe(html)
|
||||||
|
|
||||||
|
|
||||||
@register.filter
|
@register.filter
|
||||||
def format_value(value):
|
def format_value(value):
|
||||||
if getattr(value, 'is_hyperlink', False):
|
if getattr(value, 'is_hyperlink', False):
|
||||||
name = str(value.obj)
|
name = str(value.obj)
|
||||||
return mark_safe('<a href=%s>%s</a>' % (value, escape(name)))
|
return mark_safe(f'<a href={value}>{escape(name)}</a>')
|
||||||
if value is None or isinstance(value, bool):
|
if value is None or isinstance(value, bool):
|
||||||
return mark_safe('<code>%s</code>' % {True: 'true', False: 'false', None: 'null'}[value])
|
return mark_safe(f'<code>{value}</code>')
|
||||||
elif isinstance(value, list):
|
if isinstance(value, list):
|
||||||
if any(isinstance(item, (list, dict)) for item in value):
|
if any(isinstance(item, (list, dict)) for item in value):
|
||||||
template = loader.get_template('rest_framework/admin/list_value.html')
|
template = loader.get_template('rest_framework/admin/list_value.html')
|
||||||
else:
|
else:
|
||||||
template = loader.get_template('rest_framework/admin/simple_list_value.html')
|
template = loader.get_template('rest_framework/admin/simple_list_value.html')
|
||||||
context = {'value': value}
|
return template.render({'value': value})
|
||||||
return template.render(context)
|
if isinstance(value, dict):
|
||||||
elif isinstance(value, dict):
|
|
||||||
template = loader.get_template('rest_framework/admin/dict_value.html')
|
template = loader.get_template('rest_framework/admin/dict_value.html')
|
||||||
context = {'value': value}
|
return template.render({'value': value})
|
||||||
return template.render(context)
|
if isinstance(value, str):
|
||||||
elif isinstance(value, str):
|
if (value.startswith('http') or value.startswith('/')) and not re.search(r'\s', value):
|
||||||
if (
|
return mark_safe(f'<a href="{escape(value)}">{escape(value)}</a>')
|
||||||
(value.startswith('http:') or value.startswith('https:') or value.startswith('/')) and not
|
if '@' in value and not re.search(r'\s', value):
|
||||||
re.search(r'\s', value)
|
return mark_safe(f'<a href="mailto:{escape(value)}">{escape(value)}</a>')
|
||||||
):
|
if '\n' in value:
|
||||||
return mark_safe('<a href="{value}">{value}</a>'.format(value=escape(value)))
|
return mark_safe(f'<pre>{escape(value)}</pre>')
|
||||||
elif '@' in value and not re.search(r'\s', value):
|
|
||||||
return mark_safe('<a href="mailto:{value}">{value}</a>'.format(value=escape(value)))
|
|
||||||
elif '\n' in value:
|
|
||||||
return mark_safe('<pre>%s</pre>' % escape(value))
|
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
@register.filter
|
@register.filter
|
||||||
def items(value):
|
def items(value):
|
||||||
"""
|
return [] if value is None else value.items()
|
||||||
Simple filter to return the items of the dict. Useful when the dict may
|
|
||||||
have a key 'items' which is resolved first in Django template dot-notation
|
|
||||||
lookup. See issue #4931
|
|
||||||
Also see: https://stackoverflow.com/questions/15416662/django-template-loop-over-dictionary-items-with-items-as-key
|
|
||||||
"""
|
|
||||||
if value is None:
|
|
||||||
# `{% for k, v in value.items %}` doesn't raise when value is None or
|
|
||||||
# not in the context, so neither should `{% for k, v in value|items %}`
|
|
||||||
return []
|
|
||||||
return value.items()
|
|
||||||
|
|
||||||
|
|
||||||
@register.filter
|
@register.filter
|
||||||
def data(value):
|
def data(value):
|
||||||
"""
|
|
||||||
Simple filter to access `data` attribute of object,
|
|
||||||
specifically coreapi.Document.
|
|
||||||
|
|
||||||
As per `items` filter above, allows accessing `document.data` when
|
|
||||||
Document contains Link keyed-at "data".
|
|
||||||
|
|
||||||
See issue #5395
|
|
||||||
"""
|
|
||||||
return value.data
|
return value.data
|
||||||
|
|
||||||
|
|
||||||
@register.filter
|
@register.filter
|
||||||
def schema_links(section, sec_key=None):
|
def schema_links(section, sec_key=None):
|
||||||
"""
|
"""
|
||||||
Recursively find every link in a schema, even nested.
|
Recursively find every link in a schema, even nested.
|
||||||
"""
|
"""
|
||||||
NESTED_FORMAT = '%s > %s' # this format is used in docs/js/api.js:normalizeKeys
|
NESTED_FORMAT = '%s > %s'
|
||||||
links = section.links
|
links = section.links
|
||||||
if section.data:
|
if section.data:
|
||||||
data = section.data.items()
|
data = section.data.items()
|
||||||
for sub_section_key, sub_section in data:
|
for sub_section_key, sub_section in data:
|
||||||
new_links = schema_links(sub_section, sec_key=sub_section_key)
|
new_links = schema_links(sub_section, sec_key=sub_section_key)
|
||||||
links.update(new_links)
|
links.update(new_links)
|
||||||
|
|
||||||
if sec_key is not None:
|
if sec_key is not None:
|
||||||
new_links = {}
|
new_links = {NESTED_FORMAT % (sec_key, link_key): link for link_key, link in links.items()}
|
||||||
for link_key, link in links.items():
|
|
||||||
new_key = NESTED_FORMAT % (sec_key, link_key)
|
|
||||||
new_links.update({new_key: link})
|
|
||||||
return new_links
|
return new_links
|
||||||
|
|
||||||
return links
|
return links
|
||||||
|
|
||||||
|
|
||||||
@register.filter
|
@register.filter
|
||||||
def add_nested_class(value):
|
def add_nested_class(value):
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict) or (isinstance(value, list) and any(isinstance(item, (list, dict)) for item in value)):
|
||||||
return 'class=nested'
|
|
||||||
if isinstance(value, list) and any(isinstance(item, (list, dict)) for item in value):
|
|
||||||
return 'class=nested'
|
return 'class=nested'
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}"]
|
||||||
# Bunch of stuff cloned from urlize
|
WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), ('"', '"'), ("'", "'")]
|
||||||
TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"]
|
|
||||||
WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'),
|
|
||||||
('"', '"'), ("'", "'")]
|
|
||||||
word_split_re = re.compile(r'(\s+)')
|
word_split_re = re.compile(r'(\s+)')
|
||||||
simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE)
|
simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE)
|
||||||
simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE)
|
simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE)
|
||||||
simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
|
simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
|
||||||
|
|
||||||
|
|
||||||
def smart_urlquote_wrapper(matched_url):
|
def smart_urlquote_wrapper(matched_url):
|
||||||
"""
|
|
||||||
Simple wrapper for smart_urlquote. ValueError("Invalid IPv6 URL") can
|
|
||||||
be raised here, see issue #1386
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
return smart_urlquote(matched_url)
|
return smart_urlquote(matched_url)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@register.filter
|
@register.filter
|
||||||
def break_long_headers(header):
|
def break_long_headers(header):
|
||||||
"""
|
|
||||||
Breaks headers longer than 160 characters (~page length)
|
|
||||||
when possible (are comma separated)
|
|
||||||
"""
|
|
||||||
if len(header) > 160 and ',' in header:
|
if len(header) > 160 and ',' in header:
|
||||||
header = mark_safe('<br> ' + ', <br>'.join(escape(header).split(',')))
|
header = mark_safe('<br> ' + ', <br>'.join(escape(header).split(',')))
|
||||||
return header
|
return header
|
||||||
|
|
Loading…
Reference in New Issue
Block a user