"""
Tests for the throttling implementations in the throttling module.
"""

from unittest import mock

import pytest
from django.contrib.auth.models import User
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured
from django.http import HttpRequest
from django.test import TestCase

from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.throttling import (
    AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle,
    UserRateThrottle
)
from rest_framework.views import APIView


class User3SecRateThrottle(UserRateThrottle):
    """Per-user throttle: at most 3 requests per second, scope 'seconds'."""

    scope = 'seconds'
    rate = '3/sec'


class User3MinRateThrottle(UserRateThrottle):
    """Per-user throttle: at most 3 requests per minute, scope 'minutes'."""

    scope = 'minutes'
    rate = '3/min'


class User6MinRateThrottle(UserRateThrottle):
    """Per-user throttle: at most 6 requests per minute, scope 'minutes'."""

    # NOTE(review): same scope as User3MinRateThrottle. If cache keys are
    # derived from the scope (the 'throttle_<scope>_<ident>' keys asserted in
    # this file suggest so), these two should not share a view.
    scope = 'minutes'
    rate = '6/min'


class NonTimeThrottle(BaseThrottle):
    """
    Throttle whose decision does not depend on time.

    It allows the first request it ever sees and denies every later one. The
    "already seen a request" flag is stored as a ``called`` attribute on the
    class rather than on the instance (presumably because a fresh throttle
    instance is created for each request).
    """

    def allow_request(self, request, view):
        throttle_cls = type(self)
        if hasattr(throttle_cls, 'called'):
            return False
        throttle_cls.called = True
        return True


class MockView_DoubleThrottling(APIView):
    """View limited by both a 3/sec and a 6/min per-user throttle."""

    throttle_classes = (User3SecRateThrottle, User6MinRateThrottle)

    def get(self, request):
        return Response(data='foo')


class MockView(APIView):
    """View limited by a single 3/sec per-user throttle."""

    throttle_classes = (User3SecRateThrottle,)

    def get(self, request):
        return Response(data='foo')


class MockView_MinuteThrottling(APIView):
    """View limited by a single 3/min per-user throttle."""

    throttle_classes = (User3MinRateThrottle,)

    def get(self, request):
        return Response(data='foo')


class MockView_NonTimeThrottling(APIView):
    """View guarded by NonTimeThrottle (first request allowed, rest denied)."""

    throttle_classes = (NonTimeThrottle,)

    def get(self, request):
        return Response(data='foo')


class ThrottlingTests(TestCase):
    """
    Tests for UserRateThrottle subclasses and a custom BaseThrottle subclass
    attached to simple APIViews.
    """

    def setUp(self):
        """
        Reset the cache so that no throttles will be active
        """
        cache.clear()
        self.factory = APIRequestFactory()

    def test_requests_are_throttled(self):
        """
        Ensure request rate is limited
        """
        # Pin the clock so all four requests land in the same one-second
        # window, independent of test-machine speed or of other tests.
        self.set_throttle_timer(MockView, 0)
        request = self.factory.get('/')
        for dummy in range(4):
            response = MockView.as_view()(request)
        assert response.status_code == 429

    def set_throttle_timer(self, view, value):
        """
        Explicitly set the timer, overriding time.time()

        Every throttle class on ``view`` gets a ``timer`` that always returns
        ``value``. The override is installed with ``mock.patch.object`` and
        undone by a test cleanup, so it cannot leak into other tests that use
        the same module-level throttle classes. Repeated calls within one test
        stack patches. Cleanups run last-in-first-out, so the original timer
        is back in place once the test finishes.
        """
        for cls in view.throttle_classes:
            patcher = mock.patch.object(cls, 'timer', lambda self: value)
            patcher.start()
            self.addCleanup(patcher.stop)

    def test_request_throttling_expires(self):
        """
        Ensure request rate is limited for a limited duration only
        """
        self.set_throttle_timer(MockView, 0)

        request = self.factory.get('/')
        for dummy in range(4):
            response = MockView.as_view()(request)
        assert response.status_code == 429

        # Advance the timer by one second
        self.set_throttle_timer(MockView, 1)

        response = MockView.as_view()(request)
        assert response.status_code == 200

    def ensure_is_throttled(self, view, expect):
        """
        Make three requests as user 'a', then one as user 'b', and assert
        that user 'b' gets status code ``expect``.
        """
        request = self.factory.get('/')
        request.user = User.objects.create(username='a')
        for dummy in range(3):
            view.as_view()(request)
        request.user = User.objects.create(username='b')
        response = view.as_view()(request)
        assert response.status_code == expect

    def test_request_throttling_is_per_user(self):
        """
        Ensure request rate is only limited per user, not globally for
        PerUserThrottles
        """
        self.ensure_is_throttled(MockView, 200)

    def test_request_throttling_multiple_throttles(self):
        """
        Ensure all throttle classes see each request even when the request is
        already being throttled
        """
        self.set_throttle_timer(MockView_DoubleThrottling, 0)
        request = self.factory.get('/')
        for dummy in range(4):
            response = MockView_DoubleThrottling.as_view()(request)
        assert response.status_code == 429
        assert int(response['retry-after']) == 1

        # At this point our client made 4 requests (one was throttled) in a
        # second. If we advance the timer by one additional second, the client
        # should be allowed to make 2 more before being throttled by the 2nd
        # throttle class, which has a limit of 6 per minute.
        self.set_throttle_timer(MockView_DoubleThrottling, 1)
        for dummy in range(2):
            response = MockView_DoubleThrottling.as_view()(request)
            assert response.status_code == 200

        response = MockView_DoubleThrottling.as_view()(request)
        assert response.status_code == 429
        assert int(response['retry-after']) == 59

        # Just to make sure check again after two more seconds.
        self.set_throttle_timer(MockView_DoubleThrottling, 2)
        response = MockView_DoubleThrottling.as_view()(request)
        assert response.status_code == 429
        assert int(response['retry-after']) == 58

    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
        """
        Ensure each response carries the expected Retry-After header.

        ``expected_headers`` is a sequence of ``(timer, expect)`` pairs. The
        throttle clock is set to ``timer`` before each request. ``expect`` is
        the expected Retry-After value, or ``None`` if the header must be
        absent.
        """
        request = self.factory.get('/')
        for timer, expect in expected_headers:
            self.set_throttle_timer(view, timer)
            response = view.as_view()(request)
            if expect is not None:
                assert response['Retry-After'] == expect
            else:
                assert 'Retry-After' not in response

    def test_seconds_fields(self):
        """
        Ensure the Retry-After header is correct for second based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView, (
                (0, None),
                (0, None),
                (0, None),
                (0, '1')
            )
        )

    def test_minutes_fields(self):
        """
        Ensure the Retry-After header is correct for minute based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling, (
                (0, None),
                (0, None),
                (0, None),
                (0, '60')
            )
        )

    def test_next_rate_remains_constant_if_followed(self):
        """
        If a client follows the recommended next request rate,
        the throttling rate should stay constant.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling, (
                (0, None),
                (20, None),
                (40, None),
                (60, None),
                (80, None)
            )
        )

    def test_non_time_throttle(self):
        """
        Ensure a throttle whose decision does not depend on time (and which
        does not override ``wait()``) still throttles requests, but without
        sending a Retry-After header.
        """
        throttle_class = MockView_NonTimeThrottling.throttle_classes[0]

        def reset_called_flag():
            # The throttle keeps its state on the class; drop it so this test
            # can be re-run in the same process.
            if 'called' in vars(throttle_class):
                del throttle_class.called

        self.addCleanup(reset_called_flag)
        request = self.factory.get('/')

        self.assertFalse(hasattr(throttle_class, 'called'))

        response = MockView_NonTimeThrottling.as_view()(request)
        self.assertEqual(response.status_code, 200)
        self.assertNotIn('Retry-After', response)

        self.assertTrue(throttle_class.called)

        response = MockView_NonTimeThrottling.as_view()(request)
        self.assertEqual(response.status_code, 429)
        self.assertNotIn('Retry-After', response)


class ScopedRateThrottleTests(TestCase):
    """
    Tests for ScopedRateThrottle.
    """

    def setUp(self):
        # Start from an empty cache, as the other throttling test classes do,
        # so request histories stored by earlier tests cannot skew the counts.
        cache.clear()
        self.throttle = ScopedRateThrottle()

        class XYScopedRateThrottle(ScopedRateThrottle):
            TIMER_SECONDS = 0
            THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}

            def timer(self):
                return self.TIMER_SECONDS

        class XView(APIView):
            throttle_classes = (XYScopedRateThrottle,)
            throttle_scope = 'x'

            def get(self, request):
                return Response('x')

        class YView(APIView):
            throttle_classes = (XYScopedRateThrottle,)
            throttle_scope = 'y'

            def get(self, request):
                return Response('y')

        class UnscopedView(APIView):
            throttle_classes = (XYScopedRateThrottle,)

            def get(self, request):
                return Response('y')

        self.throttle_class = XYScopedRateThrottle
        self.factory = APIRequestFactory()
        self.x_view = XView.as_view()
        self.y_view = YView.as_view()
        self.unscoped_view = UnscopedView.as_view()

    def increment_timer(self, seconds=1):
        """Advance the fake clock shared by all views in this test case."""
        self.throttle_class.TIMER_SECONDS += seconds

    def test_scoped_rate_throttle(self):
        request = self.factory.get('/')

        # Should be able to hit x view 3 times per minute.
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 200
        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 429

        # Should be able to hit y view 1 time per minute.
        self.increment_timer()
        response = self.y_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.y_view(request)
        assert response.status_code == 429

        # Ensure throttles properly reset by advancing the rest of the minute
        self.increment_timer(55)

        # Should still be able to hit x view 3 times per minute.
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 429

        # Should still be able to hit y view 1 time per minute.
        self.increment_timer()
        response = self.y_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.y_view(request)
        assert response.status_code == 429

    def test_unscoped_view_not_throttled(self):
        request = self.factory.get('/')

        for idx in range(10):
            self.increment_timer()
            response = self.unscoped_view(request)
            assert response.status_code == 200

    def test_get_cache_key_returns_correct_key_if_user_is_authenticated(self):
        class DummyView:
            throttle_scope = 'user'

        request = Request(HttpRequest())
        user = User.objects.create(username='test')
        force_authenticate(request, user)
        request.user = user
        self.throttle.allow_request(request, DummyView())
        cache_key = self.throttle.get_cache_key(request, view=DummyView())
        assert cache_key == 'throttle_user_%s' % user.pk


class XffTestingBase(TestCase):
    """
    Shared fixture for X-Forwarded-For tests.

    Provides a view throttled at 1/day (scope 'test_limit') and a request with
    REMOTE_ADDR 3.3.3.3 and an X-Forwarded-For chain of three addresses.
    """

    def setUp(self):

        class Throttle(ScopedRateThrottle):
            THROTTLE_RATES = {'test_limit': '1/day'}
            TIMER_SECONDS = 0

            def timer(self):
                return self.TIMER_SECONDS

        class View(APIView):
            throttle_classes = (Throttle,)
            throttle_scope = 'test_limit'

            def get(self, request):
                return Response('test_limit')

        cache.clear()
        self.throttle = Throttle()
        self.view = View.as_view()
        self.request = APIRequestFactory().get('/some_uri')
        self.request.META['REMOTE_ADDR'] = '3.3.3.3'
        self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2'

    def config_proxy(self, num_proxies):
        """
        Set ``api_settings.NUM_PROXIES`` for the duration of the current test.

        The previous value is restored by a test cleanup, so this global
        setting does not leak into tests that run afterwards.
        """
        original = api_settings.NUM_PROXIES
        self.addCleanup(setattr, api_settings, 'NUM_PROXIES', original)
        api_settings.NUM_PROXIES = num_proxies


class IdWithXffBasicTests(XffTestingBase):
    """
    Baseline behaviour of the 1/day throttle when no proxies are trusted.
    """

    def test_accepts_request_under_limit(self):
        self.config_proxy(0)
        response = self.view(self.request)
        assert response.status_code == 200

    def test_denies_request_over_limit(self):
        self.config_proxy(0)
        self.view(self.request)  # uses up the single allowed request
        second_response = self.view(self.request)
        assert second_response.status_code == 429


class XffSpoofingTests(XffTestingBase):
    """
    Changing X-Forwarded-For entries the client controls (those before the
    trusted proxies) must not give the client a fresh throttle identity.
    """

    def _second_status_with_xff(self, num_proxies, forwarded_for):
        """Hit the view once, swap the XFF header, and return the next status."""
        self.config_proxy(num_proxies)
        self.view(self.request)
        self.request.META['HTTP_X_FORWARDED_FOR'] = forwarded_for
        return self.view(self.request).status_code

    def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
        status = self._second_status_with_xff(1, '4.4.4.4, 5.5.5.5, 2.2.2.2')
        assert status == 429

    def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
        status = self._second_status_with_xff(2, '4.4.4.4, 1.1.1.1, 2.2.2.2')
        assert status == 429


class XffUniqueMachinesTest(XffTestingBase):
    """
    Clients that differ in the X-Forwarded-For entry identified by the
    configured proxy count must be throttled independently.
    """

    def _second_status_with_xff(self, num_proxies, forwarded_for):
        """Hit the view once, swap the XFF header, and return the next status."""
        self.config_proxy(num_proxies)
        self.view(self.request)
        self.request.META['HTTP_X_FORWARDED_FOR'] = forwarded_for
        return self.view(self.request).status_code

    def test_unique_clients_are_counted_independently_with_one_proxy(self):
        status = self._second_status_with_xff(1, '0.0.0.0, 1.1.1.1, 7.7.7.7')
        assert status == 200

    def test_unique_clients_are_counted_independently_with_two_proxies(self):
        status = self._second_status_with_xff(2, '0.0.0.0, 7.7.7.7, 2.2.2.2')
        assert status == 200


class BaseThrottleTests(TestCase):
    """
    Tests for the abstract BaseThrottle interface.
    """

    def test_allow_request_raises_not_implemented_error(self):
        throttle = BaseThrottle()
        with pytest.raises(NotImplementedError):
            throttle.allow_request(request={}, view={})


class SimpleRateThrottleTests(TestCase):
    """
    Unit tests for individual SimpleRateThrottle methods.
    """

    def setUp(self):
        # These tests instantiate SimpleRateThrottle directly, which resolves
        # a rate from its class-level ``scope`` on construction (an unknown
        # scope raises ImproperlyConfigured, as tested below). Patching the
        # library class attribute is global, so restore it after each test.
        original_scope = SimpleRateThrottle.scope
        self.addCleanup(setattr, SimpleRateThrottle, 'scope', original_scope)
        SimpleRateThrottle.scope = 'anon'

    def test_get_rate_raises_error_if_scope_is_missing(self):
        throttle = SimpleRateThrottle()
        throttle.scope = None
        with pytest.raises(ImproperlyConfigured):
            throttle.get_rate()

    def test_throttle_raises_error_if_rate_is_missing(self):
        SimpleRateThrottle.scope = 'invalid scope'
        with pytest.raises(ImproperlyConfigured):
            SimpleRateThrottle()

    def test_parse_rate_returns_tuple_with_none_if_rate_not_provided(self):
        rate = SimpleRateThrottle().parse_rate(None)
        assert rate == (None, None)

    def test_allow_request_returns_true_if_rate_is_none(self):
        assert SimpleRateThrottle().allow_request(request={}, view={}) is True

    def test_get_cache_key_raises_not_implemented_error(self):
        with pytest.raises(NotImplementedError):
            SimpleRateThrottle().get_cache_key({}, {})

    def test_allow_request_returns_true_if_key_is_none(self):
        throttle = SimpleRateThrottle()
        throttle.rate = 'some rate'
        throttle.get_cache_key = lambda *args: None
        assert throttle.allow_request(request={}, view={}) is True

    def test_wait_returns_correct_waiting_time_without_history(self):
        throttle = SimpleRateThrottle()
        throttle.num_requests = 1
        throttle.duration = 60
        throttle.history = []
        waiting_time = throttle.wait()
        assert isinstance(waiting_time, float)
        assert waiting_time == 30.0

    def test_wait_returns_none_if_there_are_no_available_requests(self):
        throttle = SimpleRateThrottle()
        throttle.num_requests = 1
        throttle.duration = 60
        throttle.now = throttle.timer()
        throttle.history = [throttle.timer() for _ in range(3)]
        assert throttle.wait() is None


class AnonRateThrottleTests(TestCase):
    """
    Tests for AnonRateThrottle.get_cache_key.
    """

    def setUp(self):
        self.throttle = AnonRateThrottle()

    def test_authenticated_user_not_affected(self):
        user = User.objects.create(username='test')
        request = Request(HttpRequest())
        force_authenticate(request, user)
        request.user = user
        key = self.throttle.get_cache_key(request, view={})
        assert key is None

    def test_get_cache_key_returns_correct_value(self):
        key = self.throttle.get_cache_key(Request(HttpRequest()), view={})
        assert key == 'throttle_anon_None'