diff --git a/promo_code/business/tests/promocodes/base.py b/promo_code/business/tests/promocodes/base.py index 6a54d6c..37c74b9 100644 --- a/promo_code/business/tests/promocodes/base.py +++ b/promo_code/business/tests/promocodes/base.py @@ -23,9 +23,11 @@ def setUpTestData(cls): 'password': 'SecurePass123!', } business.models.Company.objects.create_company( - name=cls.valid_data['name'], + **cls.valid_data, + ) + + cls.company = business.models.Company.objects.get( email=cls.valid_data['email'], - password=cls.valid_data['password'], ) response = cls.client.post( diff --git a/promo_code/business/tests/promocodes/operations/test_list.py b/promo_code/business/tests/promocodes/operations/test_list.py index da94886..594ed8c 100644 --- a/promo_code/business/tests/promocodes/operations/test_list.py +++ b/promo_code/business/tests/promocodes/operations/test_list.py @@ -1,3 +1,4 @@ +import parameterized import rest_framework.status import rest_framework.test @@ -31,26 +32,6 @@ def _create_additional_promo(self): def setUpTestData(cls): business.tests.promocodes.base.BasePromoCreateTestCase.setUpTestData() - cls.valid_data = { - 'name': 'New Digital Marketing Solutions Inc.', - 'email': 'newtestcompany@example.com', - 'password': 'SecurePass123!', - } - - cls.company = business.models.Company.objects.create_company( - **cls.valid_data, - ) - - response = cls.client.post( - cls.signin_url, - { - 'email': cls.valid_data['email'], - 'password': cls.valid_data['password'], - }, - format='json', - ) - cls.new_token = response.data['access'] - cls.promo1_data = { 'description': 'Increased cashback 10% for new bank customers!', 'image_url': 'https://cdn2.thecatapi.com/images/3lo.jpg', @@ -113,7 +94,7 @@ def setUpTestData(cls): def setUp(self): self.client = rest_framework.test.APIClient() - self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.new_token) + self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token) def test_get_promos_without_token(self): client = rest_framework.test.APIClient() @@ -311,3 +292,29 @@ def test_get_promos_filter_kz_sort_active_until(self): self.assertEqual(len(data), 2) self.assertEqual(response.get('X-Total-Count'), '2') + + @parameterized.parameterized.expand( + [ + ('comma_separated', {'country': 'gb,FR'}, 3), + ('multiple_params', {'country': ['gb', 'FR']}, 3), + ], + ) + def test_country_parameter_formats(self, _, params, expected_count): + full_params = { + **params, + 'sort_by': 'active_from', + 'limit': 10, + } + + response = self.client.get(self.promo_list_url, full_params) + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + data = response.data + + self.assertEqual(len(data), expected_count) + self.assertEqual(data[0]['promo_id'], str(self.created_promos[1].id)) + self.assertEqual(data[1]['promo_id'], str(self.created_promos[0].id)) + self.assertEqual(data[2]['promo_id'], str(self.created_promos[2].id)) + self.assertEqual(response['X-Total-Count'], str(expected_count)) diff --git a/promo_code/business/tests/promocodes/validations/test_create_validation.py b/promo_code/business/tests/promocodes/validations/test_create_validation.py index 894f4f3..1156704 100644 --- a/promo_code/business/tests/promocodes/validations/test_create_validation.py +++ b/promo_code/business/tests/promocodes/validations/test_create_validation.py @@ -7,7 +7,13 @@ class TestPromoCodeCreation( business.tests.promocodes.base.BasePromoCreateTestCase, ): + + def setUp(self): + super().setUp() + self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token) + def test_create_promo_with_old_token(self): + self.client.credentials() registration_data = { 'name': 'Someone', 'email': 'mail@mail.com', @@ -92,7 +98,6 @@ def test_missing_fields(self, name, payload): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -113,7 +118,6 @@ def test_invalid_mode(self): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -134,7 +138,6 @@ def test_invalid_max_count_for_unique_mode(self): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -154,7 +157,6 @@ def test_short_description(self): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -179,7 +181,6 @@ def test_invalid_country(self, invalid_country): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -199,7 +200,6 @@ def test_nonexistent_country(self): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -222,7 +222,6 @@ def test_invalid_age_range(self): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -242,7 +241,6 @@ def test_common_with_promo_unique_provided(self): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -262,7 +260,6 @@ def test_unique_with_promo_common_provided(self): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -283,7 +280,6 @@ def test_both_promo_common_and_promo_unique_provided(self): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -303,7 +299,6 @@ def test_too_short_promo_common(self): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -358,7 +353,6 @@ def test_invalid_type_payloads(self, name, payload): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -384,7 +378,6 @@ def test_invalid_max_count(self, max_count): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, @@ -413,7 +406,6 @@ def test_invalid_image_url(self, url): self.promo_create_url, payload, format='json', - HTTP_AUTHORIZATION='Bearer ' + self.token, ) self.assertEqual( response.status_code, diff --git a/promo_code/business/tests/promocodes/validations/test_list_validation.py b/promo_code/business/tests/promocodes/validations/test_list_validation.py index 9d39055..b2a5a97 100644 --- a/promo_code/business/tests/promocodes/validations/test_list_validation.py +++ b/promo_code/business/tests/promocodes/validations/test_list_validation.py @@ -4,7 +4,7 @@ import business.tests.promocodes.base -class CompanyPromoFetchTests( +class TestPromoCodeList( business.tests.promocodes.base.BasePromoCreateTestCase, ): @@ -20,6 +20,10 @@ def setUp(self): ), ('invalid_country_format_single', {'country': 'france'}), ('invalid_country_format_multiple', {'country': 'gb,us,france'}), + ('invalid_country_does_not_exist', {'country': 'xx'}), + ('invalid_country_too_short', {'country': 'F'}), + ('invalid_country_format', {'country': 10}), + ('invalid_country_empty_string', {'country': ''}), ('unexpected_parameter', {'unexpected': 'value'}), ( 'combined_invalid_parameters', diff --git a/promo_code/business/views.py b/promo_code/business/views.py index 0f260c6..4de5f1b 100644 --- a/promo_code/business/views.py +++ b/promo_code/business/views.py @@ -152,16 +152,15 @@ def get_queryset(self): company=self.request.user, ) - countries = self.request.query_params.getlist('country', []) - country_list = [] - - for country_group in countries: - country_list.extend(country_group.split(',')) - - country_list = [c.strip() for c in country_list if c.strip()] - - if country_list: - regex_pattern = r'(' + '|'.join(map(re.escape, country_list)) + ')' + countries = [ + country.strip() + for group in self.request.query_params.getlist('country', []) + for country in group.split(',') + if country.strip() + ] + + if countries: + regex_pattern = r'(' + '|'.join(map(re.escape, countries)) + ')' queryset = queryset.filter( django.db.models.Q(target__country__iregex=regex_pattern) | django.db.models.Q(target__country__isnull=True), @@ -210,9 +209,17 @@ def _validate_countries(self, errors): country_list = [] for country_group in countries: - country_list.extend(country_group.split(',')) + parts = [part.strip() for part in country_group.split(',')] + + if any(part == '' for part in parts): + raise rest_framework.exceptions.ValidationError( + 'Invalid country format.', + ) + + country_list.extend(parts) country_list = [c.strip().upper() for c in country_list if c.strip()] + invalid_countries = [] for code in country_list: