From 8d99a4da51a502b932cfcc994f9784c4c71eb015 Mon Sep 17 00:00:00 2001 From: Ciro Monteiro Date: Sat, 13 Nov 2021 23:25:43 -0300 Subject: [PATCH] Add ModelTestCase to tests --- rest_framework/test.py | 42 ++++++++++++++++++++++++++++++++++++++++++ tests/test_testing.py | 31 ++++++++++++++++++++++++++++--- 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index 0212348ee..421b1e7f5 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -14,6 +14,7 @@ from django.test.client import RequestFactory as DjangoRequestFactory from django.utils.encoding import force_bytes from django.utils.http import urlencode +from rest_framework import status from rest_framework.compat import coreapi, requests from rest_framework.settings import api_settings @@ -410,3 +411,44 @@ class URLPatternsTestCase(testcases.SimpleTestCase): cls._module.urlpatterns = cls._module_urlpatterns else: del cls._module.urlpatterns + +class ModelTestCase(APITestCase): + model = None + + def __model_fields(self): + return [field.name for field in self.model._meta.local_fields[1:]] + + def __get_field_cases(self): + fields = self.__model_fields() + field_cases = {} + for field_name in fields: + try: + field_value = getattr(self, field_name) # must be an array + except: + raise AttributeError(f"TestCase doesn't have attribute '{field_name}'.") + field_cases[field_name] = field_value + + return field_cases + + def __cases(self): + fields = self.__model_fields() + field_cases = self.__get_field_cases() + cases = [{}] + for field_name in fields: + cases_updated = [] + for case in cases: + for field_case in field_cases[field_name]: + cases_updated.append({ **case, field_name: field_case }) + cases = cases_updated[:] + + return cases + + def test_post(self): + if self.model == None: + return + cases = self.__cases() + for case in cases: + response = self.client.post(self.url, case, format='json') + response.data.pop('id') + self.assertEqual(case, response.data) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) \ No newline at end of file diff --git a/tests/test_testing.py b/tests/test_testing.py index 5066ee142..3be1fc46e 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -1,16 +1,17 @@ from io import BytesIO import django +from django.db import models from django.contrib.auth.models import User from django.shortcuts import redirect from django.test import TestCase, override_settings from django.urls import path -from rest_framework import fields, serializers +from rest_framework import fields, serializers, viewsets from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.test import ( - APIClient, APIRequestFactory, URLPatternsTestCase, force_authenticate + APIClient, APIRequestFactory, URLPatternsTestCase, force_authenticate, ModelTestCase ) @@ -47,11 +48,27 @@ def post_view(request): return Response(serializer.validated_data) +class Person(models.Model): + first_name = models.CharField(max_length=100) + last_name = models.CharField(max_length=100) + age = models.PositiveSmallIntegerField() + +class PersonSerializer(serializers.ModelSerializer): + class Meta: + model = Person + fields = '__all__' + +class PersonViewSet(viewsets.ModelViewSet): + queryset = Person.objects.all() + serializer_class = PersonSerializer + + urlpatterns = [ path('view/', view), path('session-view/', session_view), path('redirect-view/', redirect_view), - path('post-view/', post_view) + path('post-view/', post_view), + path('persons/', PersonViewSet.as_view({'post': 'create'})) ] @@ -319,3 +336,11 @@ class TestExistingPatterns(TestCase): def test_urlpatterns(self): # sanity test to ensure that this test module does not have a '/' route assert self.client.get('/').status_code == 404 + +@override_settings(ROOT_URLCONF='tests.test_testing') +class TestModelTestCase(ModelTestCase): + model = Person + first_name = ["John", "Jane"] + last_name = ["Doe", "Roosevelt"] + age = [11, 23, 58, 13, 21] + url = "/persons/" \ No newline at end of file