Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions promo_code/business/tests/promocodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ def setUpTestData(cls):
'password': 'SecurePass123!',
}
business.models.Company.objects.create_company(
name=cls.valid_data['name'],
**cls.valid_data,
)

cls.company = business.models.Company.objects.get(
email=cls.valid_data['email'],
password=cls.valid_data['password'],
)

response = cls.client.post(
Expand Down
49 changes: 28 additions & 21 deletions promo_code/business/tests/promocodes/operations/test_list.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import parameterized
import rest_framework.status
import rest_framework.test

Expand Down Expand Up @@ -31,26 +32,6 @@ def _create_additional_promo(self):
def setUpTestData(cls):
business.tests.promocodes.base.BasePromoCreateTestCase.setUpTestData()

cls.valid_data = {
'name': 'New Digital Marketing Solutions Inc.',
'email': 'newtestcompany@example.com',
'password': 'SecurePass123!',
}

cls.company = business.models.Company.objects.create_company(
**cls.valid_data,
)

response = cls.client.post(
cls.signin_url,
{
'email': cls.valid_data['email'],
'password': cls.valid_data['password'],
},
format='json',
)
cls.new_token = response.data['access']

cls.promo1_data = {
'description': 'Increased cashback 10% for new bank customers!',
'image_url': 'https://cdn2.thecatapi.com/images/3lo.jpg',
Expand Down Expand Up @@ -113,7 +94,7 @@ def setUpTestData(cls):

def setUp(self):
self.client = rest_framework.test.APIClient()
self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.new_token)
self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token)

def test_get_promos_without_token(self):
client = rest_framework.test.APIClient()
Expand Down Expand Up @@ -311,3 +292,29 @@ def test_get_promos_filter_kz_sort_active_until(self):

self.assertEqual(len(data), 2)
self.assertEqual(response.get('X-Total-Count'), '2')

@parameterized.parameterized.expand(
[
('comma_separated', {'country': 'gb,FR'}, 3),
('multiple_params', {'country': ['gb', 'FR']}, 3),
],
)
def test_country_parameter_formats(self, _, params, expected_count):
full_params = {
**params,
'sort_by': 'active_from',
'limit': 10,
}

response = self.client.get(self.promo_list_url, full_params)
self.assertEqual(
response.status_code,
rest_framework.status.HTTP_200_OK,
)
data = response.data

self.assertEqual(len(data), expected_count)
self.assertEqual(data[0]['promo_id'], str(self.created_promos[1].id))
self.assertEqual(data[1]['promo_id'], str(self.created_promos[0].id))
self.assertEqual(data[2]['promo_id'], str(self.created_promos[2].id))
self.assertEqual(response['X-Total-Count'], str(expected_count))
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
class TestPromoCodeCreation(
business.tests.promocodes.base.BasePromoCreateTestCase,
):

def setUp(self):
super().setUp()
self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token)

def test_create_promo_with_old_token(self):
self.client.credentials()
registration_data = {
'name': 'Someone',
'email': 'mail@mail.com',
Expand Down Expand Up @@ -92,7 +98,6 @@ def test_missing_fields(self, name, payload):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -113,7 +118,6 @@ def test_invalid_mode(self):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -134,7 +138,6 @@ def test_invalid_max_count_for_unique_mode(self):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -154,7 +157,6 @@ def test_short_description(self):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -179,7 +181,6 @@ def test_invalid_country(self, invalid_country):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -199,7 +200,6 @@ def test_nonexistent_country(self):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -222,7 +222,6 @@ def test_invalid_age_range(self):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -242,7 +241,6 @@ def test_common_with_promo_unique_provided(self):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -262,7 +260,6 @@ def test_unique_with_promo_common_provided(self):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -283,7 +280,6 @@ def test_both_promo_common_and_promo_unique_provided(self):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -303,7 +299,6 @@ def test_too_short_promo_common(self):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand Down Expand Up @@ -358,7 +353,6 @@ def test_invalid_type_payloads(self, name, payload):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand All @@ -384,7 +378,6 @@ def test_invalid_max_count(self, max_count):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand Down Expand Up @@ -413,7 +406,6 @@ def test_invalid_image_url(self, url):
self.promo_create_url,
payload,
format='json',
HTTP_AUTHORIZATION='Bearer ' + self.token,
)
self.assertEqual(
response.status_code,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import business.tests.promocodes.base


class CompanyPromoFetchTests(
class TestPromoCodeList(
business.tests.promocodes.base.BasePromoCreateTestCase,
):

Expand All @@ -20,6 +20,10 @@ def setUp(self):
),
('invalid_country_format_single', {'country': 'france'}),
('invalid_country_format_multiple', {'country': 'gb,us,france'}),
('invalid_country_does_not_exist', {'country': 'xx'}),
('invalid_country_too_short', {'country': 'F'}),
('invalid_country_format', {'country': 10}),
('invalid_country_empty_string', {'country': ''}),
('unexpected_parameter', {'unexpected': 'value'}),
(
'combined_invalid_parameters',
Expand Down
29 changes: 18 additions & 11 deletions promo_code/business/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,15 @@ def get_queryset(self):
company=self.request.user,
)

countries = self.request.query_params.getlist('country', [])
country_list = []

for country_group in countries:
country_list.extend(country_group.split(','))

country_list = [c.strip() for c in country_list if c.strip()]

if country_list:
regex_pattern = r'(' + '|'.join(map(re.escape, country_list)) + ')'
countries = [
country.strip()
for group in self.request.query_params.getlist('country', [])
for country in group.split(',')
if country.strip()
]

if countries:
regex_pattern = r'(' + '|'.join(map(re.escape, countries)) + ')'
queryset = queryset.filter(
django.db.models.Q(target__country__iregex=regex_pattern)
| django.db.models.Q(target__country__isnull=True),
Expand Down Expand Up @@ -210,9 +209,17 @@ def _validate_countries(self, errors):
country_list = []

for country_group in countries:
country_list.extend(country_group.split(','))
parts = [part.strip() for part in country_group.split(',')]

if any(part == '' for part in parts):
raise rest_framework.exceptions.ValidationError(
'Invalid country format.',
)

country_list.extend(parts)

country_list = [c.strip().upper() for c in country_list if c.strip()]

invalid_countries = []

for code in country_list:
Expand Down