Update rest_framework.py

Two main enhancements:

1. Regex Precompilation:
Compile regular expressions outside of functions if they are used multiple times to avoid recompilation.

2. Security Enhancements:
Review escaping of interpolated values and limit `mark_safe` to strings whose dynamic parts have already been escaped.
This commit is contained in:
SOHAIL AHMAD 2024-07-16 04:11:41 -07:00 committed by GitHub
parent f74185b6dd
commit 5a734885a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,4 @@
import re
from django import template
from django.template import loader
from django.urls import NoReverseMatch, reverse
@ -13,9 +12,8 @@ from rest_framework.utils.urls import replace_query_param
register = template.Library()
# Precompiled regex used by ``add_class`` to locate the contents of an HTML
# ``class`` attribute. The capture is non-greedy so that it stops at the first
# closing quote; a greedy ``.*`` would run on to the last quote in the snippet
# and swallow any attributes that follow the class attribute.
class_re = re.compile(r'(?<=class=["\'])(.*?)(?=["\'])')
@register.tag(name='code')
def highlight_code(parser, token):
@ -24,7 +22,6 @@ def highlight_code(parser, token):
parser.delete_first_token()
return CodeNode(code, nodelist)
class CodeNode(template.Node):
style = 'emacs'
@ -36,56 +33,39 @@ class CodeNode(template.Node):
text = self.nodelist.render(context)
return pygments_highlight(text, self.lang, self.style)
@register.filter()
def with_location(fields, location):
    """Return the list of ``fields`` whose ``location`` equals ``location``."""
    return [
        candidate
        for candidate in fields
        if candidate.location == location
    ]
@register.simple_tag
def form_for_link(link):
    """
    Render an HTML form for the fields of ``link``.

    A ``coreschema.Object`` is built from ``link.fields`` (fields without a
    schema fall back to ``coreschema.String()``) and rendered with
    ``coreschema.render_to_form``; the result is marked safe.
    """
    # Imported lazily, as in the original implementation.
    import coreschema

    properties = {}
    for field in link.fields:
        properties[field.name] = field.schema or coreschema.String()
    required = [field.name for field in link.fields if field.required]

    form_schema = coreschema.Object(properties=properties, required=required)
    rendered = coreschema.render_to_form(form_schema)
    return mark_safe(rendered)
@register.simple_tag
def render_markdown(markdown_text):
    """
    Render ``markdown_text`` with ``apply_markdown`` and mark the result safe.

    When ``apply_markdown`` is ``None`` the text is returned unchanged.
    """
    if apply_markdown is not None:
        return mark_safe(apply_markdown(markdown_text))
    return markdown_text
@register.simple_tag
def get_pagination_html(pager):
    """Return the HTML produced by ``pager.to_html()``."""
    pagination_html = pager.to_html()
    return pagination_html
@register.simple_tag
def render_form(serializer, template_pack=None):
    """
    Render ``serializer.data`` with ``HTMLFormRenderer``.

    ``template_pack`` is passed to the renderer through the ``style``
    dictionary only when it is truthy.
    """
    style = {}
    if template_pack:
        style['template_pack'] = template_pack
    return HTMLFormRenderer().render(serializer.data, None, {'style': style})
@register.simple_tag
def render_field(field, style):
    """
    Render ``field`` using the renderer stored under ``style['renderer']``.

    A new ``HTMLFormRenderer`` is used when ``style`` has no ``renderer`` key.
    """
    fallback_renderer = HTMLFormRenderer()
    active_renderer = style.get('renderer', fallback_renderer)
    return active_renderer.render_field(field, style)
@register.simple_tag
def optional_login(request):
"""
@ -95,13 +75,10 @@ def optional_login(request):
login_url = reverse('rest_framework:login')
except NoReverseMatch:
return ''
snippet = "<li><a href='{href}?next={next}'>Log in</a></li>"
snippet = format_html(snippet, href=login_url, next=escape(request.path))
return mark_safe(snippet)
@register.simple_tag
def optional_docs_login(request):
"""
@ -111,13 +88,10 @@ def optional_docs_login(request):
login_url = reverse('rest_framework:login')
except NoReverseMatch:
return 'log in'
snippet = "<a href='{href}?next={next}'>log in</a>"
snippet = format_html(snippet, href=login_url, next=escape(request.path))
return mark_safe(snippet)
@register.simple_tag
def optional_logout(request, user, csrf_token):
"""
@ -128,7 +102,6 @@ def optional_logout(request, user, csrf_token):
except NoReverseMatch:
snippet = format_html('<li class="navbar-text">{user}</li>', user=escape(user))
return mark_safe(snippet)
snippet = """<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">
{user}
@ -143,11 +116,9 @@ def optional_logout(request, user, csrf_token):
</li>
</ul>
</li>"""
snippet = format_html(snippet, user=escape(user), href=logout_url,
next=escape(request.path), csrf_token=csrf_token)
snippet = format_html(snippet, user=escape(user), href=logout_url, next=escape(request.path), csrf_token=csrf_token)
return mark_safe(snippet)
@register.simple_tag
def add_query_param(request, key, val):
"""
@ -157,170 +128,98 @@ def add_query_param(request, key, val):
uri = iri_to_uri(iri)
return escape(replace_query_param(uri, key, val))
@register.filter
def as_string(value):
    """Return ``value`` formatted with ``'%s'``, or ``''`` when it is ``None``."""
    if value is None:
        return ''
    return '%s' % value
@register.filter
def as_list_of_strings(value):
    """
    Return a list with each item of ``value`` formatted with ``'%s'``.

    ``None`` items become the empty string.
    """
    return [
        '' if entry is None else '%s' % entry
        for entry in value
    ]
@register.filter
def add_class(value, css_class):
    """
    https://stackoverflow.com/questions/4124220/django-adding-css-classes-when-rendering-form-fields-in-a-template

    Inserts classes into template variables that contain HTML tags,
    useful for modifying forms without needing to change the Form objects.

    Usage:

        {{ field.label_tag|add_class:"control-label" }}

    In the case of REST Framework, the filter is used to add Bootstrap-specific
    classes to the forms.

    Returns ``value`` itself when ``css_class`` is already present, otherwise a
    safe string containing the modified markup.
    """
    html = str(value)
    match = class_re.search(html)
    if match is None:
        # No class attribute yet: add one to the first tag in the snippet.
        return mark_safe(html.replace('>', ' class="%s">' % css_class, 1))

    existing = match.group(1)
    if css_class in existing.split():
        # Nothing to change. Return the original object rather than
        # ``mark_safe(str(value))`` so that an unmodified value is never
        # promoted to a safe string by this filter.
        return value

    updated = existing + ' ' + css_class
    # A callable replacement keeps backslashes or ``\g<...>`` sequences in the
    # class text from being interpreted as regex replacement escapes, and
    # ``count=1`` limits the change to the attribute that was inspected above
    # instead of overwriting every class attribute in the snippet.
    return mark_safe(class_re.sub(lambda _match: updated, html, count=1))
@register.filter
def format_value(value):
    """
    Format ``value`` for display in HTML.

    * Hyperlink objects (``is_hyperlink`` attribute) become an anchor whose
      text is ``str(value.obj)``.
    * ``None`` and booleans become ``<code>null</code>``, ``<code>true</code>``
      or ``<code>false</code>``.
    * Lists and dicts are rendered with the ``rest_framework/admin`` templates.
    * Whitespace-free strings that start with ``http:``, ``https:`` or ``/``
      become links; whitespace-free strings containing ``@`` become ``mailto:``
      links; strings containing a newline are wrapped in ``<pre>``.
    * Anything else is returned as ``str(value)`` (not marked safe).
    """
    if getattr(value, 'is_hyperlink', False):
        name = str(value.obj)
        # Quote and escape the href so the URL cannot break out of the
        # attribute; the previous markup interpolated it raw and unquoted.
        return mark_safe('<a href="%s">%s</a>' % (escape(value), escape(name)))

    if value is None or isinstance(value, bool):
        # Use the JSON spellings; formatting the Python object directly
        # would render "True", "False" and "None".
        literal = {True: 'true', False: 'false', None: 'null'}[value]
        return mark_safe('<code>%s</code>' % literal)

    if isinstance(value, list):
        if any(isinstance(item, (list, dict)) for item in value):
            template_name = 'rest_framework/admin/list_value.html'
        else:
            template_name = 'rest_framework/admin/simple_list_value.html'
        return loader.get_template(template_name).render({'value': value})

    if isinstance(value, dict):
        dict_template = loader.get_template('rest_framework/admin/dict_value.html')
        return dict_template.render({'value': value})

    if isinstance(value, str):
        has_whitespace = re.search(r'\s', value) is not None
        # Require the scheme separator: a bare ``startswith('http')`` would
        # also turn words such as "httpfoo" into links.
        if not has_whitespace and value.startswith(('http:', 'https:', '/')):
            escaped = escape(value)
            return mark_safe('<a href="%s">%s</a>' % (escaped, escaped))
        if not has_whitespace and '@' in value:
            escaped = escape(value)
            return mark_safe('<a href="mailto:%s">%s</a>' % (escaped, escaped))
        if '\n' in value:
            return mark_safe('<pre>%s</pre>' % escape(value))

    return str(value)
@register.filter
def items(value):
    """
    Simple filter to return the items of the dict. Useful when the dict may
    have a key 'items' which is resolved first in Django template dot-notation
    lookup. See issue #4931
    Also see: https://stackoverflow.com/questions/15416662/django-template-loop-over-dictionary-items-with-items-as-key
    """
    if value is not None:
        return value.items()
    # `{% for k, v in value.items %}` doesn't raise when value is None or
    # not in the context, so neither should `{% for k, v in value|items %}`
    return []
@register.filter
def data(value):
    """
    Simple filter to access `data` attribute of object,
    specifically coreapi.Document.

    As per `items` filter above, allows accessing `document.data` when
    Document contains Link keyed-at "data".

    See issue #5395
    """
    document_data = value.data
    return document_data
@register.filter
def schema_links(section, sec_key=None):
    """
    Recursively find every link in a schema, even nested.

    Links found in sub-sections are merged into ``section.links``. When
    ``sec_key`` is given, a new dict is returned whose keys are prefixed with
    ``sec_key``; otherwise the merged ``section.links`` mapping is returned.
    """
    # this format is used in docs/js/api.js:normalizeKeys
    nested_format = '%s > %s'

    links = section.links
    if section.data:
        for child_key, child_section in section.data.items():
            links.update(schema_links(child_section, sec_key=child_key))

    if sec_key is None:
        return links

    return {
        nested_format % (sec_key, link_key): link
        for link_key, link in links.items()
    }
@register.filter
def add_nested_class(value):
    """
    Return ``'class=nested'`` for a dict, or for a list that contains at least
    one list or dict; return ``''`` otherwise.
    """
    if isinstance(value, dict):
        return 'class=nested'
    if isinstance(value, list):
        for item in value:
            if isinstance(item, (list, dict)):
                return 'class=nested'
    return ''
# Bunch of stuff cloned from urlize
# NOTE: the bare "'" entry is intentionally kept in TRAILING_PUNCTUATION;
# dropping it changes the value of a module-level constant.
TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"]
WRAPPING_PUNCTUATION = [
    ('(', ')'),
    ('<', '>'),
    ('[', ']'),
    ('&lt;', '&gt;'),
    ('"', '"'),
    ("'", "'"),
]
word_split_re = re.compile(r'(\s+)')
simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE)
simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE)
simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
def smart_urlquote_wrapper(matched_url):
    """
    Simple wrapper for smart_urlquote. ValueError("Invalid IPv6 URL") can
    be raised here, see issue #1386

    Returns the quoted URL, or ``None`` when ``smart_urlquote`` raises
    ``ValueError``.
    """
    try:
        quoted_url = smart_urlquote(matched_url)
    except ValueError:
        return None
    return quoted_url
@register.filter
def break_long_headers(header):
    """
    Breaks headers longer than 160 characters (~page length)
    when possible (are comma separated)
    """
    if len(header) <= 160 or ',' not in header:
        return header
    segments = escape(header).split(',')
    return mark_safe('<br> ' + ', <br>'.join(segments))