mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 22:04:48 +03:00
Added RouteTree to routers, refactored SimpleRouter to support new-style and old-style route paths
This commit is contained in:
parent
e1768bdc16
commit
8f9f8b31d3
|
@ -53,6 +53,50 @@ def flatten(list_of_lists):
|
|||
return itertools.chain(*list_of_lists)
|
||||
|
||||
|
||||
class RouteTree:
|
||||
"""
|
||||
Stores a tree structure of routes.
|
||||
"""
|
||||
Node = namedtuple('Node', ['routes', 'value'])
|
||||
|
||||
def __init__(self):
|
||||
self.routes = {}
|
||||
|
||||
def set(self, value, path, name):
|
||||
"""
|
||||
Set a value on a given path. Raises a KeyError if the value has already
|
||||
been set or if the path doesn't exist.
|
||||
"""
|
||||
routes = self.routes
|
||||
|
||||
for part in path:
|
||||
if part not in routes:
|
||||
# TODO: invalid path error message
|
||||
raise KeyError('')
|
||||
routes = routes[path].routes
|
||||
|
||||
if name in routes:
|
||||
# TODO: route name already exists error message
|
||||
raise KeyError('')
|
||||
|
||||
routes[name] = self.Node({}, value)
|
||||
|
||||
def get(self, path, name):
|
||||
"""
|
||||
Get a vector of (name, value) for each entry on the path and name.
|
||||
Raises KeyError if path or name is invalid.
|
||||
"""
|
||||
routes = self.routes
|
||||
vect = []
|
||||
|
||||
for part in path + [name]:
|
||||
node = routes[part]
|
||||
vect.append((part, node.value))
|
||||
routes = node.routes
|
||||
|
||||
return vect
|
||||
|
||||
|
||||
class BaseRouter(object):
|
||||
def __init__(self):
|
||||
self.registry = []
|
||||
|
@ -234,13 +278,70 @@ class SimpleRouter(BaseRouter):
|
|||
lookup_value=lookup_value
|
||||
)
|
||||
|
||||
def get_route_nodes_path(self, route_nodes):
|
||||
"""
|
||||
Given a vector of (prefix, viewset) tuples, get the resulting URL
|
||||
regex. The URL kwarg name and lookup is determined by the prefix
|
||||
and settings for each view.
|
||||
|
||||
Duplicate prefixes on the path will be appended with a number. For
|
||||
example, the path:
|
||||
[('users', UserView), ('users', UserView), ('friends', UserView)]
|
||||
would result in:
|
||||
/users/(?P<users_pk>[^\]+)/users/(?P<users_2_pk>[^/]+)/friends
|
||||
"""
|
||||
encountered = {}
|
||||
parts = []
|
||||
|
||||
for i, node in enumerate(route_nodes):
|
||||
prefix, viewset = node
|
||||
if i == len(route_nodes) - 1:
|
||||
parts.append(prefix)
|
||||
|
||||
else:
|
||||
lookup_prefix = '%s_' % prefix
|
||||
if prefix in encountered:
|
||||
encountered[prefix] += 1
|
||||
lookup_prefix = '%s_%s_' % (prefix, encountered[prefix])
|
||||
else:
|
||||
encountered[prefix] = 1
|
||||
parts.append('%s/%s' % (prefix, self.get_lookup_regex(viewset, lookup_prefix)))
|
||||
|
||||
return '/'.join(parts)
|
||||
|
||||
def _process_registry(self):
|
||||
"""
|
||||
New-style routes use a list of strings for the prefix. Normalize old-
|
||||
style string prefixes and then sort based on path length and prefix.
|
||||
|
||||
If a viewset was registered for all nodes that have nested nodes then
|
||||
the result should construct a valid RouteTree when inserted in order.
|
||||
"""
|
||||
def normalize(entry):
|
||||
prefix, viewset, base_name = entry
|
||||
if isinstance(prefix, str):
|
||||
prefix = [prefix]
|
||||
return prefix, viewset, base_name
|
||||
|
||||
def sort(entry):
|
||||
full_path = entry[0]
|
||||
prefix, path = full_path[-1], full_path[:-1]
|
||||
return len(path), prefix
|
||||
|
||||
return sorted(map(normalize, self.registry), key=sort)
|
||||
|
||||
def get_urls(self):
|
||||
"""
|
||||
Use the registered viewsets to generate a list of URL patterns.
|
||||
"""
|
||||
ret = []
|
||||
urls = []
|
||||
route_tree = RouteTree()
|
||||
|
||||
for prefix, viewset, basename in self.registry:
|
||||
for full_path, viewset, basename in self._process_registry():
|
||||
prefix, path = full_path[-1], full_path[:-1]
|
||||
route_tree.set(viewset, path, prefix)
|
||||
route_nodes = route_tree.get(path, prefix)
|
||||
prefix = self.get_route_nodes_path(route_nodes)
|
||||
lookup = self.get_lookup_regex(viewset)
|
||||
routes = self.get_routes(viewset)
|
||||
|
||||
|
@ -260,9 +361,9 @@ class SimpleRouter(BaseRouter):
|
|||
|
||||
view = viewset.as_view(mapping, **route.initkwargs)
|
||||
name = route.name.format(basename=basename)
|
||||
ret.append(url(regex, view, name=name))
|
||||
urls.append(url(regex, view, name=name))
|
||||
|
||||
return ret
|
||||
return urls
|
||||
|
||||
|
||||
class DefaultRouter(SimpleRouter):
|
||||
|
|
Loading…
Reference in New Issue
Block a user