diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index dba8153b1..39a10eb21 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -1,5 +1,4 @@ import re - from django import template from django.template import loader from django.urls import NoReverseMatch, reverse @@ -13,9 +12,8 @@ from rest_framework.utils.urls import replace_query_param register = template.Library() -# Regex for adding classes to html snippets -class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') - +# Precompile regex patterns +class_re = re.compile(r'(?<=class=["\'])(.*?)(?=["\'])') @register.tag(name='code') def highlight_code(parser, token): @@ -24,7 +22,6 @@ def highlight_code(parser, token): parser.delete_first_token() return CodeNode(code, nodelist) - class CodeNode(template.Node): style = 'emacs' @@ -36,56 +33,39 @@ class CodeNode(template.Node): text = self.nodelist.render(context) return pygments_highlight(text, self.lang, self.style) - @register.filter() def with_location(fields, location): - return [ - field for field in fields - if field.location == location - ] - + return [field for field in fields if field.location == location] @register.simple_tag def form_for_link(link): import coreschema - properties = { - field.name: field.schema or coreschema.String() - for field in link.fields - } - required = [ - field.name - for field in link.fields - if field.required - ] + properties = {field.name: field.schema or coreschema.String() for field in link.fields} + required = [field.name for field in link.fields if field.required] schema = coreschema.Object(properties=properties, required=required) return mark_safe(coreschema.render_to_form(schema)) - @register.simple_tag def render_markdown(markdown_text): if apply_markdown is None: return markdown_text return mark_safe(apply_markdown(markdown_text)) - @register.simple_tag def get_pagination_html(pager): return pager.to_html() - @register.simple_tag def render_form(serializer, template_pack=None): style = {'template_pack': template_pack} if template_pack else {} renderer = HTMLFormRenderer() return renderer.render(serializer.data, None, {'style': style}) - @register.simple_tag def render_field(field, style): renderer = style.get('renderer', HTMLFormRenderer()) return renderer.render_field(field, style) - @register.simple_tag def optional_login(request): """ @@ -95,13 +75,10 @@ def optional_login(request): login_url = reverse('rest_framework:login') except NoReverseMatch: return '' - snippet = "
  • Log in
  • " snippet = format_html(snippet, href=login_url, next=escape(request.path)) - return mark_safe(snippet) - @register.simple_tag def optional_docs_login(request): """ @@ -111,13 +88,10 @@ def optional_docs_login(request): login_url = reverse('rest_framework:login') except NoReverseMatch: return 'log in' - snippet = "log in" snippet = format_html(snippet, href=login_url, next=escape(request.path)) - return mark_safe(snippet) - @register.simple_tag def optional_logout(request, user, csrf_token): """ @@ -128,7 +102,6 @@ def optional_logout(request, user, csrf_token): except NoReverseMatch: snippet = format_html('', user=escape(user)) return mark_safe(snippet) - snippet = """ """ - snippet = format_html(snippet, user=escape(user), href=logout_url, - next=escape(request.path), csrf_token=csrf_token) + snippet = format_html(snippet, user=escape(user), href=logout_url, next=escape(request.path), csrf_token=csrf_token) return mark_safe(snippet) - @register.simple_tag def add_query_param(request, key, val): """ @@ -157,170 +128,98 @@ def add_query_param(request, key, val): uri = iri_to_uri(iri) return escape(replace_query_param(uri, key, val)) - @register.filter def as_string(value): - if value is None: - return '' - return '%s' % value - + return '' if value is None else '%s' % value @register.filter def as_list_of_strings(value): - return [ - '' if (item is None) else ('%s' % item) - for item in value - ] - + return ['' if item is None else '%s' % item for item in value] @register.filter def add_class(value, css_class): - """ - https://stackoverflow.com/questions/4124220/django-adding-css-classes-when-rendering-form-fields-in-a-template - - Inserts classes into template variables that contain HTML tags, - useful for modifying forms without needing to change the Form objects. - - Usage: - - {{ field.label_tag|add_class:"control-label" }} - - In the case of REST Framework, the filter is used to add Bootstrap-specific - classes to the forms. - """ html = str(value) match = class_re.search(html) if match: - m = re.search(r'^%s$|^%s\s|\s%s\s|\s%s$' % (css_class, css_class, - css_class, css_class), - match.group(1)) - if not m: - return mark_safe(class_re.sub(match.group(1) + " " + css_class, - html)) + classes = match.group(1) + if css_class not in classes.split(): + classes += f" {css_class}" + html = class_re.sub(classes, html) else: - return mark_safe(html.replace('>', ' class="%s">' % css_class, 1)) - return value - + html = html.replace('>', f' class="{css_class}">', 1) + return mark_safe(html) @register.filter def format_value(value): if getattr(value, 'is_hyperlink', False): name = str(value.obj) - return mark_safe('%s' % (value, escape(name))) + return mark_safe(f'{escape(name)}') if value is None or isinstance(value, bool): - return mark_safe('%s' % {True: 'true', False: 'false', None: 'null'}[value]) - elif isinstance(value, list): + return mark_safe(f'{value}') + if isinstance(value, list): if any(isinstance(item, (list, dict)) for item in value): template = loader.get_template('rest_framework/admin/list_value.html') else: template = loader.get_template('rest_framework/admin/simple_list_value.html') - context = {'value': value} - return template.render(context) - elif isinstance(value, dict): + return template.render({'value': value}) + if isinstance(value, dict): template = loader.get_template('rest_framework/admin/dict_value.html') - context = {'value': value} - return template.render(context) - elif isinstance(value, str): - if ( - (value.startswith('http:') or value.startswith('https:') or value.startswith('/')) and not - re.search(r'\s', value) - ): - return mark_safe('{value}'.format(value=escape(value))) - elif '@' in value and not re.search(r'\s', value): - return mark_safe('{value}'.format(value=escape(value))) - elif '\n' in value: - return mark_safe('
    %s
    ' % escape(value)) + return template.render({'value': value}) + if isinstance(value, str): + if (value.startswith('http') or value.startswith('/')) and not re.search(r'\s', value): + return mark_safe(f'{escape(value)}') + if '@' in value and not re.search(r'\s', value): + return mark_safe(f'{escape(value)}') + if '\n' in value: + return mark_safe(f'
    {escape(value)}
    ') return str(value) - @register.filter def items(value): - """ - Simple filter to return the items of the dict. Useful when the dict may - have a key 'items' which is resolved first in Django template dot-notation - lookup. See issue #4931 - Also see: https://stackoverflow.com/questions/15416662/django-template-loop-over-dictionary-items-with-items-as-key - """ - if value is None: - # `{% for k, v in value.items %}` doesn't raise when value is None or - # not in the context, so neither should `{% for k, v in value|items %}` - return [] - return value.items() - + return [] if value is None else value.items() @register.filter def data(value): - """ - Simple filter to access `data` attribute of object, - specifically coreapi.Document. - - As per `items` filter above, allows accessing `document.data` when - Document contains Link keyed-at "data". - - See issue #5395 - """ return value.data - @register.filter def schema_links(section, sec_key=None): """ Recursively find every link in a schema, even nested. """ - NESTED_FORMAT = '%s > %s' # this format is used in docs/js/api.js:normalizeKeys + NESTED_FORMAT = '%s > %s' links = section.links if section.data: data = section.data.items() for sub_section_key, sub_section in data: new_links = schema_links(sub_section, sec_key=sub_section_key) links.update(new_links) - if sec_key is not None: - new_links = {} - for link_key, link in links.items(): - new_key = NESTED_FORMAT % (sec_key, link_key) - new_links.update({new_key: link}) + new_links = {NESTED_FORMAT % (sec_key, link_key): link for link_key, link in links.items()} return new_links - return links - @register.filter def add_nested_class(value): - if isinstance(value, dict): - return 'class=nested' - if isinstance(value, list) and any(isinstance(item, (list, dict)) for item in value): + if isinstance(value, dict) or (isinstance(value, list) and any(isinstance(item, (list, dict)) for item in value)): return 'class=nested' return '' - -# Bunch of stuff cloned from urlize -TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"] -WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), - ('"', '"'), ("'", "'")] +TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}"] +WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), ('"', '"'), ("'", "'")] word_split_re = re.compile(r'(\s+)') simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE) simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) simple_email_re = re.compile(r'^\S+@\S+\.\S+$') - def smart_urlquote_wrapper(matched_url): - """ - Simple wrapper for smart_urlquote. ValueError("Invalid IPv6 URL") can - be raised here, see issue #1386 - """ try: return smart_urlquote(matched_url) except ValueError: return None - @register.filter def break_long_headers(header): - """ - Breaks headers longer than 160 characters (~page length) - when possible (are comma separated) - """ if len(header) > 160 and ',' in header: header = mark_safe('
    ' + ',
    '.join(escape(header).split(','))) return header