diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda..850bb408e 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -53,6 +53,50 @@ def flatten(list_of_lists): return itertools.chain(*list_of_lists) +class RouteTree: + """ + Stores a tree structure of routes. + """ + Node = namedtuple('Node', ['routes', 'value']) + + def __init__(self): + self.routes = {} + + def set(self, value, path, name): + """ + Set a value on a given path. Raises a KeyError if the value has already + been set or if the path doesn't exist. + """ + routes = self.routes + + for part in path: + if part not in routes: + # TODO: invalid path error message + raise KeyError('') + routes = routes[path].routes + + if name in routes: + # TODO: route name already exists error message + raise KeyError('') + + routes[name] = self.Node({}, value) + + def get(self, path, name): + """ + Get a vector of (name, value) for each entry on the path and name. + Raises KeyError if path or name is invalid. + """ + routes = self.routes + vect = [] + + for part in path + [name]: + node = routes[part] + vect.append((part, node.value)) + routes = node.routes + + return vect + + class BaseRouter(object): def __init__(self): self.registry = [] @@ -234,13 +278,70 @@ class SimpleRouter(BaseRouter): lookup_value=lookup_value ) + def get_route_nodes_path(self, route_nodes): + """ + Given a vector of (prefix, viewset) tuples, get the resulting URL + regex. The URL kwarg name and lookup is determined by the prefix + and settings for each view. + + Duplicate prefixes on the path will be appended with a number. For + example, the path: + [('users', UserView), ('users', UserView), ('friends', UserView)] + would result in: + /users/(?P[^\]+)/users/(?P[^/]+)/friends + """ + encountered = {} + parts = [] + + for i, node in enumerate(route_nodes): + prefix, viewset = node + if i == len(route_nodes) - 1: + parts.append(prefix) + + else: + lookup_prefix = '%s_' % prefix + if prefix in encountered: + encountered[prefix] += 1 + lookup_prefix = '%s_%s_' % (prefix, encountered[prefix]) + else: + encountered[prefix] = 1 + parts.append('%s/%s' % (prefix, self.get_lookup_regex(viewset, lookup_prefix))) + + return '/'.join(parts) + + def _process_registry(self): + """ + New-style routes use a list of strings for the prefix. Normalize old- + style string prefixes and then sort based on path length and prefix. + + If a viewset was registered for all nodes that have nested nodes then + the result should construct a valid RouteTree when inserted in order. + """ + def normalize(entry): + prefix, viewset, base_name = entry + if isinstance(prefix, str): + prefix = [prefix] + return prefix, viewset, base_name + + def sort(entry): + full_path = entry[0] + prefix, path = full_path[-1], full_path[:-1] + return len(path), prefix + + return sorted(map(normalize, self.registry), key=sort) + def get_urls(self): """ Use the registered viewsets to generate a list of URL patterns. """ - ret = [] + urls = [] + route_tree = RouteTree() - for prefix, viewset, basename in self.registry: + for full_path, viewset, basename in self._process_registry(): + prefix, path = full_path[-1], full_path[:-1] + route_tree.set(viewset, path, prefix) + route_nodes = route_tree.get(path, prefix) + prefix = self.get_route_nodes_path(route_nodes) lookup = self.get_lookup_regex(viewset) routes = self.get_routes(viewset) @@ -260,9 +361,9 @@ class SimpleRouter(BaseRouter): view = viewset.as_view(mapping, **route.initkwargs) name = route.name.format(basename=basename) - ret.append(url(regex, view, name=name)) + urls.append(url(regex, view, name=name)) - return ret + return urls class DefaultRouter(SimpleRouter):