From 21a02573937fbe0a66d077ce043e49a159c7471d Mon Sep 17 00:00:00 2001 From: Arnav Choudhury Date: Sat, 30 Jan 2021 16:24:39 +0530 Subject: [PATCH] Refactored test_{urls,views,models} to use django's built-in get_user_model() method to get the custom user model instead of importing it from a specific django app. Makes the code less prone to breaking. Updated the code for test_user_count method to handle the case of existing users in the database since the local development db can be used to run tests. --- .../users/tests/test_models.py | 4 ++-- .../users/tests/test_tasks.py | 16 ++++++++++++---- .../users/tests/test_urls.py | 4 ++-- .../users/tests/test_views.py | 3 ++- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_models.py b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_models.py index 3194be1fd..e7180db16 100644 --- a/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_models.py +++ b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_models.py @@ -1,8 +1,8 @@ import pytest - -from {{ cookiecutter.project_slug }}.users.models import User +from django.contrib.auth import get_user_model pytestmark = pytest.mark.django_db +User = get_user_model() def test_user_get_absolute_url(user: User): diff --git a/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_tasks.py b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_tasks.py index 41d5af292..8cf40d088 100644 --- a/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_tasks.py +++ b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_tasks.py @@ -9,8 +9,16 @@ pytestmark = pytest.mark.django_db def test_user_count(settings): """A basic test to execute the get_users_count Celery task.""" - UserFactory.create_batch(3) settings.CELERY_TASK_ALWAYS_EAGER = True - task_result = get_users_count.delay() - assert isinstance(task_result, EagerResult) - assert task_result.result == 3 + + # Get all existing users in the DB + current_users = get_users_count.delay() + assert isinstance(current_users, EagerResult) + + # Create and add 3 more users to the DB. + UserFactory.create_batch(3) + + # Get number of newly added users in the DB + task_result = get_users_count.delay().result - current_users.result + + assert task_result == 3 diff --git a/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_urls.py b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_urls.py index aab6d0a87..f5cc874a4 100644 --- a/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_urls.py +++ b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_urls.py @@ -1,9 +1,9 @@ import pytest +from django.contrib.auth import get_user_model from django.urls import resolve, reverse -from {{ cookiecutter.project_slug }}.users.models import User - pytestmark = pytest.mark.django_db +User = get_user_model() def test_detail(user: User): diff --git a/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_views.py b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_views.py index c2fe8b519..075cbf8df 100644 --- a/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_views.py +++ b/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/users/tests/test_views.py @@ -1,5 +1,6 @@ import pytest from django.contrib import messages +from django.contrib.auth import get_user_model from django.contrib.auth.models import AnonymousUser from django.contrib.messages.middleware import MessageMiddleware from django.contrib.sessions.middleware import SessionMiddleware @@ -7,7 +8,6 @@ from django.http.response import Http404 from django.test import RequestFactory from {{ cookiecutter.project_slug }}.users.forms import UserChangeForm -from {{ cookiecutter.project_slug }}.users.models import User from {{ cookiecutter.project_slug }}.users.tests.factories import UserFactory from {{ cookiecutter.project_slug }}.users.views import ( UserRedirectView, @@ -16,6 +16,7 @@ from {{ cookiecutter.project_slug }}.users.views import ( ) pytestmark = pytest.mark.django_db +User = get_user_model() class TestUserUpdateView: