mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-10-24 20:51:19 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			240 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			240 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| generators.py   # Top-down schema generation
 | |
| 
 | |
| See schemas.__init__.py for package overview.
 | |
| """
 | |
| import re
 | |
| from importlib import import_module
 | |
| 
 | |
| from django.conf import settings
 | |
| from django.contrib.admindocs.views import simplify_regex
 | |
| from django.core.exceptions import PermissionDenied
 | |
| from django.http import Http404
 | |
| from django.urls import URLPattern, URLResolver
 | |
| 
 | |
| from rest_framework import exceptions
 | |
| from rest_framework.request import clone_request
 | |
| from rest_framework.settings import api_settings
 | |
| from rest_framework.utils.model_meta import _get_pk
 | |
| 
 | |
| 
 | |
def get_pk_name(model):
    """
    Return the `name` of the field that `_get_pk` resolves from the options
    (`_meta`) of `model`'s concrete model -- presumably its primary key field.
    """
    concrete_meta = model._meta.concrete_model._meta
    pk_field = _get_pk(concrete_meta)
    return pk_field.name
 | |
| 
 | |
| 
 | |
def is_api_view(callback):
    """
    Report whether `callback` wraps a REST framework view/viewset.

    A callback qualifies when it carries a `cls` attribute that is a
    subclass of `APIView`; a missing (or `None`) `cls` yields `False`.
    """
    # Imported here rather than at module level to avoid an import cycle.
    from rest_framework.views import APIView

    view_cls = getattr(callback, 'cls', None)
    if view_cls is None:
        return False
    return issubclass(view_cls, APIView)
 | |
| 
 | |
| 
 | |
def endpoint_ordering(endpoint):
    """
    Sort key for `(path, method, callback)` endpoint tuples.

    Only the HTTP method takes part in the key: GET, POST, PUT, PATCH and
    DELETE rank 0-4 in that order, and every other method ranks 5. The
    result is a one-element tuple.
    """
    # The path and callback are unpacked (so a malformed endpoint still
    # fails loudly) but play no role in the ordering.
    _, method, _ = endpoint
    priorities = {'GET': 0, 'POST': 1, 'PUT': 2, 'PATCH': 3, 'DELETE': 4}
    rank = priorities.get(method, 5)
    return (rank,)
 | |
| 
 | |
| 
 | |
# Matches a single angle-bracketed path parameter, either `<parameter>` or
# `<converter:parameter>`. The optional `converter` group captures the text
# before the colon; the `parameter` group captures the `\w+` name after it.
_PATH_PARAMETER_COMPONENT_RE = re.compile(
    r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>'
)
 | |
| 
 | |
| 
 | |
class EndpointEnumerator:
    """
    A class to determine the available API endpoints that a project exposes.

    It walks a list of URL patterns (given directly, or read from a URLconf)
    and collects one `(path, method, callback)` tuple per HTTP method of
    every view that `should_include_endpoint` accepts.
    """
    def __init__(self, patterns=None, urlconf=None):
        """
        `patterns` is used as-is when given. Otherwise `urlpatterns` is read
        from `urlconf`, which may be a dotted module path (imported here) or
        an object exposing `urlpatterns`; when `urlconf` is also omitted,
        `settings.ROOT_URLCONF` is used.
        """
        if patterns is None:
            if urlconf is None:
                # Use the default Django URL conf
                urlconf = settings.ROOT_URLCONF

            # Load the given URLconf module
            if isinstance(urlconf, str):
                urls = import_module(urlconf)
            else:
                urls = urlconf
            patterns = urls.urlpatterns

        self.patterns = patterns

    def get_api_endpoints(self, patterns=None, prefix=''):
        """
        Return a list of all available API endpoints by inspecting the URL conf.

        `patterns` defaults to the patterns held by this instance; `prefix`
        is the accumulated pattern text of the enclosing resolvers. The
        result is a list of `(path, method, callback)` tuples sorted with
        `endpoint_ordering`.
        """
        if patterns is None:
            patterns = self.patterns

        api_endpoints = []

        for pattern in patterns:
            path_regex = prefix + str(pattern.pattern)
            if isinstance(pattern, URLPattern):
                path = self.get_path_from_regex(path_regex)
                callback = pattern.callback
                if self.should_include_endpoint(path, callback):
                    api_endpoints.extend(
                        (path, method, callback)
                        for method in self.get_allowed_methods(callback)
                    )

            elif isinstance(pattern, URLResolver):
                # Recurse into included URL confs, carrying the prefix along.
                nested_endpoints = self.get_api_endpoints(
                    patterns=pattern.url_patterns,
                    prefix=path_regex
                )
                api_endpoints.extend(nested_endpoints)

        return sorted(api_endpoints, key=endpoint_ordering)

    def get_path_from_regex(self, path_regex):
        """
        Given a URL conf regex, return a URI template string.
        """
        # ???: Would it be feasible to adjust this such that we generate the
        # path, plus the kwargs, plus the type from the converter, such that we
        # could feed that straight into the parameter schema object?

        path = simplify_regex(path_regex)

        # Strip Django 2.0 converters as they are incompatible with uritemplate format
        return _PATH_PARAMETER_COMPONENT_RE.sub(r'{\g<parameter>}', path)

    def should_include_endpoint(self, path, callback):
        """
        Return `True` if the given endpoint should be included.

        An endpoint is excluded when the callback is not a REST framework
        view, when its schema is `None` (on the view class or passed through
        `initkwargs`), or when the path ends in a `.{format}` suffix.
        """
        if not is_api_view(callback):
            return False  # Ignore anything except REST framework views.

        if callback.cls.schema is None:
            return False

        initkwargs = callback.initkwargs
        if 'schema' in initkwargs and initkwargs['schema'] is None:
            return False

        if path.endswith(('.{format}', '.{format}/')):
            return False  # Ignore .json style URLs.

        return True

    def get_allowed_methods(self, callback):
        """
        Return a list of the valid HTTP methods for this endpoint.

        When the callback has `actions`, the result is the upper-cased action
        keys that also appear in the view class's `http_method_names`, in the
        order the actions are declared. Otherwise the `allowed_methods` of a
        fresh view instance are used. 'OPTIONS' and 'HEAD' are always dropped.
        """
        if hasattr(callback, 'actions'):
            http_method_names = set(callback.cls.http_method_names)
            # Iterate the actions themselves rather than a set intersection:
            # set iteration order for strings varies between processes, which
            # made the order of the returned methods non-deterministic.
            # dict.fromkeys() drops duplicates while preserving order.
            methods = [
                method.upper()
                for method in dict.fromkeys(callback.actions)
                if method in http_method_names
            ]
        else:
            methods = callback.cls().allowed_methods

        return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
 | |
| 
 | |
| 
 | |
class BaseSchemaGenerator:
    """
    Base class for schema generators.

    Holds the generator configuration, discovers endpoints through
    `endpoint_inspector_cls`, and provides helpers for turning URL callbacks
    into view instances. Subclasses must implement `get_schema()`.
    """
    endpoint_inspector_cls = EndpointEnumerator

    # 'pk' isn't great as an externally exposed name for an identifier,
    # so by default we prefer to use the actual model field name for schemas.
    # Set by 'SCHEMA_COERCE_PATH_PK'.
    coerce_path_pk = None

    def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
        # A non-empty base URL is normalised to always end with a slash.
        needs_trailing_slash = bool(url) and not url.endswith('/')
        self.url = url + '/' if needs_trailing_slash else url

        self.title = title
        self.description = description
        self.version = version

        self.patterns = patterns
        self.urlconf = urlconf

        self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK

        # Populated on demand by `_initialise_endpoints()`.
        self.endpoints = None

    def _initialise_endpoints(self):
        """
        Discover the endpoints once; later calls are no-ops.
        """
        if self.endpoints is not None:
            return
        inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
        self.endpoints = inspector.get_api_endpoints()

    def _get_paths_and_endpoints(self, request):
        """
        Turn each stored (path, method, callback) into (path, method, view).

        Returns a `(paths, view_endpoints)` pair, where `paths` lists the
        coerced path of every endpoint in the same order as `view_endpoints`.
        """
        view_endpoints = []
        for raw_path, method, callback in self.endpoints:
            view = self.create_view(callback, method, request)
            coerced = self.coerce_path(raw_path, method, view)
            view_endpoints.append((coerced, method, view))

        paths = [entry[0] for entry in view_endpoints]
        return paths, view_endpoints

    def create_view(self, callback, method, request=None):
        """
        Build a view instance for `callback` as if it were handling `method`.

        The view is constructed with the callback's `initkwargs` (if any) and
        given empty args/kwargs. When the callback has `actions`, the view's
        `action` is set from them ('metadata' for OPTIONS). When `request` is
        given, the view receives a clone of it for `method`.
        """
        init_kwargs = getattr(callback, 'initkwargs', {})
        view = callback.cls(**init_kwargs)
        view.args = ()
        view.kwargs = {}
        view.format_kwarg = None
        view.request = None

        actions = getattr(callback, 'actions', None)
        view.action_map = actions
        if actions is not None:
            view.action = 'metadata' if method == 'OPTIONS' else actions.get(method.lower())

        if request is not None:
            view.request = clone_request(request, method)

        return view

    def coerce_path(self, path, method, view):
        """
        Replace a `{pk}` path argument with the name of the model field,
        where possible, since that reads better externally ("this is an
        identifier" rather than "this is a database primary key").

        Falls back to `id` when the view has no queryset model. The path is
        returned unchanged when coercion is disabled or `{pk}` is absent.
        """
        if self.coerce_path_pk and '{pk}' in path:
            queryset = getattr(view, 'queryset', None)
            model = getattr(queryset, 'model', None)
            field_name = get_pk_name(model) if model else 'id'
            return path.replace('{pk}', '{%s}' % field_name)
        return path

    def get_schema(self, request=None, public=False):
        raise NotImplementedError(".get_schema() must be implemented in subclasses.")

    def has_view_permissions(self, path, method, view):
        """
        Return `True` if the view's request passes its permission checks.

        A view without a request is always permitted; otherwise an API
        exception, `Http404` or `PermissionDenied` raised by
        `check_permissions` means the request is not permitted.
        """
        request = view.request
        if request is None:
            return True

        try:
            view.check_permissions(request)
        except (exceptions.APIException, Http404, PermissionDenied):
            return False
        else:
            return True
 |