From 75777df4215954bb586926a3481e87bd4eec3132 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 7 Apr 2025 22:25:27 +0300 Subject: [PATCH] test: Add tests for the GET api/business/promo/list endpoint. - Add tests that check different filtering methods - Add tests that check parameter validation - Add some additional parameter validation to the GET api/business/promo/list endpoint --- promo_code/business/pagination.py | 25 +- promo_code/business/tests/promocodes/base.py | 3 + .../tests/promocodes/operations/test_list.py | 313 ++++++++++++++++++ .../validations/test_list_validation.py | 58 ++++ promo_code/business/views.py | 77 ++++- 5 files changed, 449 insertions(+), 27 deletions(-) create mode 100644 promo_code/business/tests/promocodes/operations/test_list.py create mode 100644 promo_code/business/tests/promocodes/validations/test_list_validation.py diff --git a/promo_code/business/pagination.py b/promo_code/business/pagination.py index 61dfe05..58604fb 100644 --- a/promo_code/business/pagination.py +++ b/promo_code/business/pagination.py @@ -12,22 +12,15 @@ class CustomLimitOffsetPagination( def get_limit(self, request): param_limit = request.query_params.get(self.limit_query_param) if param_limit is not None: - try: - limit = int(param_limit) - if limit < 0: - raise rest_framework.exceptions.ValidationError( - 'Limit cannot be negative.', - ) - - if limit == 0: - return 0 - - if self.max_limit: - return min(limit, self.max_limit) - - return limit - except (TypeError, ValueError): - pass + limit = int(param_limit) + + if limit == 0: + return 0 + + if self.max_limit: + return min(limit, self.max_limit) + + return limit return self.default_limit diff --git a/promo_code/business/tests/promocodes/base.py b/promo_code/business/tests/promocodes/base.py index dbdd851..6a54d6c 100644 --- a/promo_code/business/tests/promocodes/base.py +++ b/promo_code/business/tests/promocodes/base.py @@ -12,6 +12,9 @@ def setUpTestData(cls): super().setUpTestData() cls.client = rest_framework.test.APIClient() cls.promo_create_url = django.urls.reverse('api-business:promo-create') + cls.promo_list_url = django.urls.reverse( + 'api-business:company-promo-list', + ) cls.signup_url = django.urls.reverse('api-business:company-sign-up') cls.signin_url = django.urls.reverse('api-business:company-sign-in') cls.valid_data = { diff --git a/promo_code/business/tests/promocodes/operations/test_list.py b/promo_code/business/tests/promocodes/operations/test_list.py new file mode 100644 index 0000000..da94886 --- /dev/null +++ b/promo_code/business/tests/promocodes/operations/test_list.py @@ -0,0 +1,313 @@ +import rest_framework.status +import rest_framework.test + +import business.models +import business.tests.promocodes.base + + +class TestPromoEndpoint( + business.tests.promocodes.base.BasePromoCreateTestCase, +): + def _create_additional_promo(self): + self.__class__.promo5_data = { + 'description': 'Special offer: bonus reward for loyal customers', + 'target': {'country': 'Kz'}, + 'max_count': 10, + 'active_from': '2026-05-01', + 'mode': 'COMMON', + 'promo_common': 'special-10', + } + response_create = self.client.post( + self.promo_create_url, + self.__class__.promo5_data, + format='json', + ) + self.assertEqual( + response_create.status_code, + rest_framework.status.HTTP_201_CREATED, + ) + + @classmethod + def setUpTestData(cls): + business.tests.promocodes.base.BasePromoCreateTestCase.setUpTestData() + + cls.valid_data = { + 'name': 'New Digital Marketing Solutions Inc.', + 'email': 'newtestcompany@example.com', + 'password': 'SecurePass123!', + } + + cls.company = business.models.Company.objects.create_company( + **cls.valid_data, + ) + + response = cls.client.post( + cls.signin_url, + { + 'email': cls.valid_data['email'], + 'password': cls.valid_data['password'], + }, + format='json', + ) + cls.new_token = response.data['access'] + + cls.promo1_data = { + 'description': 'Increased cashback 10% for new bank customers!', + 'image_url': 'https://cdn2.thecatapi.com/images/3lo.jpg', + 'target': {}, + 'max_count': 10, + 'active_from': '2025-01-10', + 'mode': 'COMMON', + 'promo_common': 'sale-10', + } + cls.promo2_data = { + 'description': 'Increased cashback 40% for new bank customers!', + 'image_url': 'https://cdn2.thecatapi.com/images/3lo.jpg', + 'target': {'age_from': 15, 'country': 'fr'}, + 'max_count': 100, + 'active_from': '2028-12-20', + 'mode': 'COMMON', + 'promo_common': 'sale-40', + } + cls.promo3_data = { + 'description': 'Gift sleep mask when applying for a car loan', + 'target': {'age_from': 28, 'age_until': 50, 'country': 'gb'}, + 'max_count': 1, + 'active_from': '2025-01-01', + 'active_until': '2028-12-30', + 'mode': 'UNIQUE', + 'promo_unique': ['uniq1', 'uniq2', 'uniq3'], + } + cls.promo5_data = { + 'description': 'Special offer: bonus reward for loyal customers', + 'target': {'country': 'Kz'}, + 'max_count': 10, + 'active_from': '2026-05-01', + 'mode': 'COMMON', + 'promo_common': 'special-10', + } + cls.created_promos = [] + + for promo_data in [cls.promo1_data, cls.promo2_data, cls.promo3_data]: + promo = business.models.Promo.objects.create( + company=cls.company, + description=promo_data['description'], + image_url=promo_data.get('image_url'), + target=promo_data['target'], + max_count=promo_data['max_count'], + active_from=promo_data.get('active_from'), + active_until=promo_data.get('active_until'), + mode=promo_data['mode'], + ) + if promo.mode == 'COMMON': + promo.promo_common = promo_data.get('promo_common') + promo.save() + else: + promo_codes = [ + business.models.PromoCode(promo=promo, code=code) + for code in promo_data.get('promo_unique', []) + ] + business.models.PromoCode.objects.bulk_create(promo_codes) + + cls.created_promos.append(promo) + + def setUp(self): + self.client = rest_framework.test.APIClient() + self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.new_token) + + def test_get_promos_without_token(self): + client = rest_framework.test.APIClient() + response = client.get(self.promo_list_url) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_401_UNAUTHORIZED, + ) + + def test_get_all_promos(self): + response = self.client.get(self.promo_list_url) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + + self.assertEqual(len(data), 3) + self.assertEqual(data[0]['promo_id'], str(self.created_promos[2].id)) + self.assertEqual(data[1]['promo_id'], str(self.created_promos[1].id)) + self.assertEqual(data[2]['promo_id'], str(self.created_promos[0].id)) + self.assertEqual(response.headers.get('X-Total-Count'), '3') + + def test_get_promos_with_pagination_offset_1(self): + response = self.client.get(self.promo_list_url, {'offset': 1}) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + + self.assertEqual(len(data), 2) + self.assertEqual(data[0]['promo_id'], str(self.created_promos[1].id)) + self.assertEqual(data[1]['promo_id'], str(self.created_promos[0].id)) + self.assertEqual(response.headers.get('X-Total-Count'), '3') + + def test_get_promos_with_pagination_offset_1_limit_1(self): + response = self.client.get( + self.promo_list_url, + {'offset': 1, 'limit': 1}, + ) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + self.assertEqual(len(data), 1) + self.assertEqual(data[0]['promo_id'], str(self.created_promos[1].id)) + self.assertEqual(response.get('X-Total-Count'), '3') + + def test_get_promos_with_pagination_offset_100(self): + response = self.client.get(self.promo_list_url, {'offset': 100}) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + self.assertEqual(len(data), 0) + self.assertEqual(response.get('X-Total-Count'), '3') + + def test_get_promos_filter_country_gb(self): + response = self.client.get(self.promo_list_url, {'country': 'gb'}) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + + self.assertEqual(len(data), 2) + self.assertEqual(data[0]['promo_id'], str(self.created_promos[2].id)) + self.assertEqual(data[1]['promo_id'], str(self.created_promos[0].id)) + self.assertEqual(response.get('X-Total-Count'), '2') + + def test_get_promos_filter_country_gb_sort_active_until(self): + response = self.client.get( + self.promo_list_url, + {'country': 'gb', 'sort_by': 'active_until'}, + ) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + + self.assertEqual(len(data), 2) + self.assertEqual(data[0]['promo_id'], str(self.created_promos[0].id)) + self.assertEqual(data[1]['promo_id'], str(self.created_promos[2].id)) + self.assertEqual(response.get('X-Total-Count'), '2') + + def test_get_promos_filter_country_gb_fr_sort_active_from_limit_10(self): + response = self.client.get( + self.promo_list_url, + {'country': 'gb,FR', 'sort_by': 'active_from', 'limit': 10}, + ) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + + self.assertEqual(len(data), 3) + self.assertEqual(data[0]['promo_id'], str(self.created_promos[1].id)) + self.assertEqual(data[1]['promo_id'], str(self.created_promos[0].id)) + self.assertEqual(data[2]['promo_id'], str(self.created_promos[2].id)) + self.assertEqual(response.get('X-Total-Count'), '3') + + def test_get_promos_filter_country_gb_fr_sort_active_from_limit_2_offset_2( + self, + ): + response = self.client.get( + self.promo_list_url, + { + 'country': 'gb,FR', + 'sort_by': 'active_from', + 'limit': 2, + 'offset': 2, + }, + ) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + self.assertEqual(len(data), 1) + self.assertEqual(data[0]['promo_id'], str(self.created_promos[2].id)) + self.assertEqual(response.get('X-Total-Count'), '3') + + def test_get_promos_filter_country_gb_fr_us_sort_active_from_limit_2(self): + response = self.client.get( + self.promo_list_url, + {'country': 'gb,FR,us', 'sort_by': 'active_from', 'limit': 2}, + ) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + + self.assertEqual(len(data), 2) + self.assertEqual(data[0]['promo_id'], str(self.created_promos[1].id)) + self.assertEqual(data[1]['promo_id'], str(self.created_promos[0].id)) + self.assertEqual(response.get('X-Total-Count'), '3') + + def test_get_promos_limit_zero(self): + response = self.client.get(self.promo_list_url, {'limit': 0}) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + self.assertEqual(len(data), 0) + + def test_create_and_get_promos(self): + self._create_additional_promo() + + response_list = self.client.get( + self.promo_list_url, + {'country': 'gb,FR,Kz', 'sort_by': 'active_from', 'limit': 10}, + ) + self.assertEqual( + response_list.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response_list.data + + self.assertEqual(len(data), 4) + self.assertEqual(response_list.get('X-Total-Count'), '4') + + def test_get_promos_filter_gb_kz_fr(self): + self._create_additional_promo() + response = self.client.get( + self.promo_list_url, + {'country': 'gb,Kz,FR', 'sort_by': 'active_from', 'limit': 10}, + ) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + + self.assertEqual(len(data), 4) + self.assertEqual(response.get('X-Total-Count'), '4') + + def test_get_promos_filter_kz_sort_active_until(self): + self._create_additional_promo() + response = self.client.get( + self.promo_list_url, + {'country': 'Kz', 'sort_by': 'active_until', 'limit': 10}, + ) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + + self.assertEqual(len(data), 2) + self.assertEqual(response.get('X-Total-Count'), '2') diff --git a/promo_code/business/tests/promocodes/validations/test_list_validation.py b/promo_code/business/tests/promocodes/validations/test_list_validation.py new file mode 100644 index 0000000..9d39055 --- /dev/null +++ b/promo_code/business/tests/promocodes/validations/test_list_validation.py @@ -0,0 +1,58 @@ +import parameterized +import rest_framework.status + +import business.tests.promocodes.base + + +class CompanyPromoFetchTests( + business.tests.promocodes.base.BasePromoCreateTestCase, +): + + def setUp(self): + super().setUp() + self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token) + + @parameterized.parameterized.expand( + [ + ( + 'invalid_sort_by_format', + {'country': 'fr', 'sort_by': 'active_untilllll'}, + ), + ('invalid_country_format_single', {'country': 'france'}), + ('invalid_country_format_multiple', {'country': 'gb,us,france'}), + ('unexpected_parameter', {'unexpected': 'value'}), + ( + 'combined_invalid_parameters', + { + 'country': 'france', + 'limit': -1, + 'sort_by': 'non_existing_field', + }, + ), + ], + ) + def test_invalid_query_string_parameters(self, name, params): + response = self.client.get(self.promo_list_url, params) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_400_BAD_REQUEST, + ) + + @parameterized.parameterized.expand( + [ + ('invalid_limit_format', {'limit': 'france'}), + ('invalid_offset_format', {'offset': 'france'}), + ('negative_offset', {'offset': -5}), + ('negative_limit', {'limit': -5}), + ('invalid_float_limit', {'limit': 5.5}), + ('invalid_float_offset', {'offset': 3.5}), + ('empty_string_limit', {'limit': ''}), + ('empty_string_offset', {'offset': ''}), + ], + ) + def test_invalid_numeric_parameters(self, name, params): + response = self.client.get(self.promo_list_url, params) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_400_BAD_REQUEST, + ) diff --git a/promo_code/business/views.py b/promo_code/business/views.py index c64f3fc..0f260c6 100644 --- a/promo_code/business/views.py +++ b/promo_code/business/views.py @@ -23,6 +23,8 @@ class CompanySignUpView( core.views.BaseCustomResponseMixin, rest_framework.generics.CreateAPIView, ): + serializer_class = business.serializers.CompanySignUpSerializer + def post(self, request): try: serializer = business.serializers.CompanySignUpSerializer( @@ -65,6 +67,8 @@ class CompanySignInView( core.views.BaseCustomResponseMixin, rest_framework_simplejwt.views.TokenObtainPairView, ): + serializer_class = business.serializers.CompanySignInSerializer + def post(self, request): try: serializer = business.serializers.CompanySignInSerializer( @@ -171,8 +175,37 @@ def get_queryset(self): return queryset # noqa: R504 + def list(self, request, *args, **kwargs): + try: + self.validate_query_params() + except rest_framework.exceptions.ValidationError as e: + return rest_framework.response.Response( + e.detail, + status=rest_framework.status.HTTP_400_BAD_REQUEST, + ) + + return super().list(request, *args, **kwargs) + def validate_query_params(self): + self._validate_allowed_params() errors = {} + self._validate_countries(errors) + self._validate_sort_by(errors) + self._validate_offset() + self._validate_limit() + if errors: + raise rest_framework.exceptions.ValidationError(errors) + + def _validate_allowed_params(self): + allowed_params = {'country', 'limit', 'offset', 'sort_by'} + unexpected_params = ( + set(self.request.query_params.keys()) - allowed_params + ) + + if unexpected_params: + raise rest_framework.exceptions.ValidationError('Invalid params.') + + def _validate_countries(self, errors): countries = self.request.query_params.getlist('country', []) country_list = [] @@ -183,6 +216,10 @@ def validate_query_params(self): invalid_countries = [] for code in country_list: + if len(code) != 2: + invalid_countries.append(code) + continue + try: pycountry.countries.lookup(code) except LookupError: @@ -193,6 +230,7 @@ def validate_query_params(self): f'Invalid country codes: {", ".join(invalid_countries)}' ) + def _validate_sort_by(self, errors): sort_by = self.request.query_params.get('sort_by') if sort_by and sort_by not in ['active_from', 'active_until']: errors['sort_by'] = ( @@ -200,16 +238,33 @@ def validate_query_params(self): 'Available values: active_from, active_until' ) - if errors: - raise rest_framework.exceptions.ValidationError(errors) + def _validate_offset(self): + offset = self.request.query_params.get('offset') + if offset is not None: + try: + offset = int(offset) + except (TypeError, ValueError): + raise rest_framework.exceptions.ValidationError( + 'Invalid offset format.', + ) - def list(self, request, *args, **kwargs): - try: - self.validate_query_params() - except rest_framework.exceptions.ValidationError as e: - return rest_framework.response.Response( - e.detail, - status=rest_framework.status.HTTP_400_BAD_REQUEST, - ) + if offset < 0: + raise rest_framework.exceptions.ValidationError( + 'Offset cannot be negative.', + ) - return super().list(request, *args, **kwargs) + def _validate_limit(self): + limit = self.request.query_params.get('limit') + + if limit is not None: + try: + limit = int(limit) + except (TypeError, ValueError): + raise rest_framework.exceptions.ValidationError( + 'Invalid limit format.', + ) + + if limit < 0: + raise rest_framework.exceptions.ValidationError( + 'Limit cannot be negative.', + )