diff --git a/promo_code/core/views.py b/promo_code/core/views.py index 50fad55..e6f149b 100644 --- a/promo_code/core/views.py +++ b/promo_code/core/views.py @@ -2,20 +2,9 @@ import django.views import rest_framework.permissions import rest_framework.response -import rest_framework.status import rest_framework.views -class BaseCustomResponseMixin: - error_response = {'status': 'error', 'message': 'Error in request data.'} - - def handle_validation_error(self): - return rest_framework.response.Response( - self.error_response, - status=rest_framework.status.HTTP_400_BAD_REQUEST, - ) - - class PingView(django.views.View): def get(self, request, *args, **kwargs): return django.http.HttpResponse('PROOOOOOOOOOOOOOOOOD', status=200) diff --git a/promo_code/user/constants.py b/promo_code/user/constants.py new file mode 100644 index 0000000..9681d90 --- /dev/null +++ b/promo_code/user/constants.py @@ -0,0 +1,18 @@ +AGE_MIN = 0 +AGE_MAX = 100 + +COUNTRY_CODE_LENGTH = 2 + +PASSWORD_MIN_LENGTH = 8 +PASSWORD_MAX_LENGTH = 60 + +NAME_MIN_LENGTH = 1 +NAME_MAX_LENGTH = 100 + +SURNAME_MIN_LENGTH = 1 +SURNAME_MAX_LENGTH = 120 + +EMAIL_MIN_LENGTH = 8 +EMAIL_MAX_LENGTH = 120 + +AVATAR_URL_MAX_LENGTH = 350 diff --git a/promo_code/user/models.py b/promo_code/user/models.py index dacda78..c926ae2 100644 --- a/promo_code/user/models.py +++ b/promo_code/user/models.py @@ -2,6 +2,8 @@ import django.db.models import django.utils.timezone +import user.constants + class UserManager(django.contrib.auth.models.BaseUserManager): def create_user(self, email, name, surname, password=None, **extra_fields): @@ -36,13 +38,20 @@ class User( django.contrib.auth.models.AbstractBaseUser, django.contrib.auth.models.PermissionsMixin, ): - email = django.db.models.EmailField(unique=True, max_length=120) - name = django.db.models.CharField(max_length=100) - surname = django.db.models.CharField(max_length=120) + email = django.db.models.EmailField( + unique=True, + max_length=user.constants.EMAIL_MAX_LENGTH, + ) + name = django.db.models.CharField( + max_length=user.constants.NAME_MAX_LENGTH, + ) + surname = django.db.models.CharField( + max_length=user.constants.SURNAME_MAX_LENGTH, + ) avatar_url = django.db.models.URLField( blank=True, null=True, - max_length=350, + max_length=user.constants.AVATAR_URL_MAX_LENGTH, ) other = django.db.models.JSONField(default=dict) diff --git a/promo_code/user/serializers.py b/promo_code/user/serializers.py index ee91085..b395f22 100644 --- a/promo_code/user/serializers.py +++ b/promo_code/user/serializers.py @@ -1,40 +1,67 @@ import django.contrib.auth.password_validation import django.core.exceptions import django.core.validators +import django.db.models +import pycountry import rest_framework.exceptions import rest_framework.serializers -import rest_framework.status import rest_framework_simplejwt.serializers import rest_framework_simplejwt.token_blacklist.models as tb_models import rest_framework_simplejwt.tokens +import user.constants import user.models as user_models import user.validators +class OtherFieldSerializer(rest_framework.serializers.Serializer): + age = rest_framework.serializers.IntegerField( + required=True, + min_value=user.constants.AGE_MIN, + max_value=user.constants.AGE_MAX, + ) + country = rest_framework.serializers.CharField( + required=True, + max_length=user.constants.COUNTRY_CODE_LENGTH, + min_length=user.constants.COUNTRY_CODE_LENGTH, + ) + + def validate(self, value): + country = value['country'].upper() + + try: + pycountry.countries.lookup(country) + except LookupError: + raise rest_framework.serializers.ValidationError( + 'Invalid ISO 3166-1 alpha-2 country code.', + ) + + return value + + class SignUpSerializer(rest_framework.serializers.ModelSerializer): password = rest_framework.serializers.CharField( write_only=True, required=True, validators=[django.contrib.auth.password_validation.validate_password], - max_length=60, - min_length=8, + max_length=user.constants.PASSWORD_MAX_LENGTH, + min_length=user.constants.PASSWORD_MIN_LENGTH, style={'input_type': 'password'}, ) name = rest_framework.serializers.CharField( required=True, - min_length=1, - max_length=100, + min_length=user.constants.NAME_MIN_LENGTH, + max_length=user.constants.NAME_MAX_LENGTH, ) surname = rest_framework.serializers.CharField( required=True, - min_length=1, - max_length=120, + min_length=user.constants.SURNAME_MIN_LENGTH, + max_length=user.constants.SURNAME_MAX_LENGTH, ) email = rest_framework.serializers.EmailField( required=True, - min_length=8, - max_length=120, + min_length=user.constants.EMAIL_MIN_LENGTH, + max_length=user.constants.EMAIL_MAX_LENGTH, validators=[ user.validators.UniqueEmailValidator( 'This email address is already registered.', @@ -44,15 +71,12 @@ class SignUpSerializer(rest_framework.serializers.ModelSerializer): ) avatar_url = rest_framework.serializers.CharField( required=False, - max_length=350, + max_length=user.constants.AVATAR_URL_MAX_LENGTH, validators=[ django.core.validators.URLValidator(schemes=['http', 'https']), ], ) - other = rest_framework.serializers.JSONField( - required=True, - validators=[user.validators.OtherFieldValidator()], - ) + other = OtherFieldSerializer(required=True) class Meta: model = user_models.User @@ -94,13 +118,14 @@ class SignInSerializer( def validate(self, attrs): user = self.authenticate_user(attrs) - self.update_token_version(user) + user.token_version = django.db.models.F('token_version') + 1 + user.save(update_fields=['token_version']) data = super().validate(attrs) refresh = rest_framework_simplejwt.tokens.RefreshToken(data['refresh']) - self.invalidate_previous_tokens(user, refresh['jti']) + self.blacklist_other_tokens(user, refresh['jti']) return data @@ -128,19 +153,18 @@ def authenticate_user(self, attrs): return user - def invalidate_previous_tokens(self, user, current_jti): - outstanding_tokens = tb_models.OutstandingToken.objects.filter( - user=user, - ).exclude(jti=current_jti) - - for token in outstanding_tokens: - tb_models.BlacklistedToken.objects.get_or_create(token=token) - - def update_token_version(self, user): - user.token_version += 1 - user.save() + def blacklist_other_tokens(self, user, current_jti): + qs = tb_models.OutstandingToken.objects.filter(user=user).exclude( + jti=current_jti, + ) + blacklisted = [tb_models.BlacklistedToken(token=tok) for tok in qs] + tb_models.BlacklistedToken.objects.bulk_create( + blacklisted, + ignore_conflicts=True, + ) - def get_token(self, user): + @classmethod + def get_token(cls, user): token = super().get_token(user) token['token_version'] = user.token_version return token diff --git a/promo_code/user/urls.py b/promo_code/user/urls.py index c2bc733..2d0e27b 100644 --- a/promo_code/user/urls.py +++ b/promo_code/user/urls.py @@ -9,12 +9,12 @@ urlpatterns = [ django.urls.path( 'auth/sign-up', - user.views.SignUpView.as_view(), + user.views.UserSignUpView.as_view(), name='sign-up', ), django.urls.path( 'auth/sign-in', - rest_framework_simplejwt.views.TokenObtainPairView.as_view(), + user.views.UserSignInView.as_view(), name='sign-in', ), django.urls.path( diff --git a/promo_code/user/validators.py b/promo_code/user/validators.py index 2a62b5e..cbd8941 100644 --- a/promo_code/user/validators.py +++ b/promo_code/user/validators.py @@ -1,6 +1,4 @@ -import pycountry import rest_framework.exceptions -import rest_framework.serializers import user.models @@ -24,69 +22,3 @@ def __call__(self, value): ) exc.status_code = self.status_code raise exc - - -class OtherFieldValidator(rest_framework.serializers.Serializer): - """ - Validates JSON fields: - - age (required, 0-100) - - country (required, valid ISO 3166-1 alpha-2) - """ - - country_codes = {c.alpha_2 for c in pycountry.countries} - - age = rest_framework.serializers.IntegerField( - required=True, - min_value=0, - max_value=100, - error_messages={ - 'required': 'This field is required.', - 'invalid': 'Must be an integer.', - 'min_value': 'Must be between 0 and 100.', - 'max_value': 'Must be between 0 and 100.', - }, - ) - - country = rest_framework.serializers.CharField( - required=True, - max_length=2, - min_length=2, - error_messages={ - 'required': 'This field is required.', - 'blank': 'Must be a 2-letter ISO code.', - 'max_length': 'Must be a 2-letter ISO code.', - 'min_length': 'Must be a 2-letter ISO code.', - }, - ) - - def validate_country(self, value): - country = value.upper() - if country not in self.country_codes: - raise rest_framework.serializers.ValidationError( - 'Invalid ISO 3166-1 alpha-2 country code.', - ) - - return country - - def __call__(self, value): - if not isinstance(value, dict): - raise rest_framework.serializers.ValidationError( - {'non_field_errors': ['Must be a JSON object']}, - ) - - missing_fields = [ - field - for field in self.fields - if field not in value or value.get(field) in (None, '') - ] - - if missing_fields: - raise rest_framework.serializers.ValidationError( - dict.fromkeys(missing_fields, 'This field is required.'), - ) - - serializer = self.__class__(data=value) - if not serializer.is_valid(): - raise rest_framework.serializers.ValidationError(serializer.errors) - - return value diff --git a/promo_code/user/views.py b/promo_code/user/views.py index bf52e72..5afc44a 100644 --- a/promo_code/user/views.py +++ b/promo_code/user/views.py @@ -1,67 +1,39 @@ -import rest_framework.exceptions import rest_framework.generics import rest_framework.response -import rest_framework.serializers import rest_framework.status -import rest_framework_simplejwt.exceptions import rest_framework_simplejwt.tokens import rest_framework_simplejwt.views -import core.views import user.serializers -class SignUpView( - core.views.BaseCustomResponseMixin, +class UserSignUpView( rest_framework.generics.CreateAPIView, ): serializer_class = user.serializers.SignUpSerializer def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) - - try: - serializer.is_valid(raise_exception=True) - except rest_framework.exceptions.ValidationError: - return self.handle_validation_error() + serializer.is_valid(raise_exception=True) user = serializer.save() - refresh = rest_framework_simplejwt.tokens.RefreshToken.for_user(user) refresh['token_version'] = user.token_version + access_token = refresh.access_token + response_data = { + 'access': str(access_token), + 'refresh': str(refresh), + } + return rest_framework.response.Response( - {'access': str(access_token), 'refresh': str(refresh)}, + response_data, status=rest_framework.status.HTTP_200_OK, ) -class SignInView( - core.views.BaseCustomResponseMixin, +class UserSignInView( rest_framework_simplejwt.views.TokenObtainPairView, ): serializer_class = user.serializers.SignInSerializer - - def post(self, request, *args, **kwargs): - try: - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - except ( - rest_framework.serializers.ValidationError, - rest_framework_simplejwt.exceptions.TokenError, - ) as e: - if isinstance(e, rest_framework.serializers.ValidationError): - return self.handle_validation_error() - - raise rest_framework_simplejwt.exceptions.InvalidToken(str(e)) - - response_data = { - 'access': serializer.validated_data['access'], - 'refresh': serializer.validated_data['refresh'], - } - - return rest_framework.response.Response( - response_data, - status=rest_framework.status.HTTP_200_OK, - )