diff --git a/api/custom_auth/oauth/serializers.py b/api/custom_auth/oauth/serializers.py index 6a1e80ab90af..cf16c008191b 100644 --- a/api/custom_auth/oauth/serializers.py +++ b/api/custom_auth/oauth/serializers.py @@ -6,13 +6,11 @@ from django.db.models import F from rest_framework import serializers from rest_framework.authtoken.models import Token -from rest_framework.exceptions import PermissionDenied -from organisations.invites.models import Invite from users.auth_type import AuthType from users.models import SignUpType -from ..constants import USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE +from ..serializers import InviteLinkValidationMixin from .github import GithubUser from .google import get_user_info @@ -20,7 +18,7 @@ UserModel = get_user_model() -class OAuthLoginSerializer(serializers.Serializer): +class OAuthLoginSerializer(InviteLinkValidationMixin, serializers.Serializer): access_token = serializers.CharField( required=True, help_text="Code or access token returned from the FE interaction with the third party login provider.", @@ -85,12 +83,9 @@ def _get_user(self, user_data: dict): if not existing_user: sign_up_type = self.validated_data.get("sign_up_type") - if not ( - settings.ALLOW_REGISTRATION_WITHOUT_INVITE - or sign_up_type == SignUpType.INVITE_LINK.value - or Invite.objects.filter(email=email).exists() - ): - raise PermissionDenied(USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE) + self._validate_registration_invite( + email=email, sign_up_type=self.validated_data.get("sign_up_type") + ) return UserModel.objects.create( **user_data, email=email.lower(), sign_up_type=sign_up_type diff --git a/api/custom_auth/serializers.py b/api/custom_auth/serializers.py index 55bb43e595ae..11bd80828a6b 100644 --- a/api/custom_auth/serializers.py +++ b/api/custom_auth/serializers.py @@ -5,7 +5,7 @@ from rest_framework.exceptions import PermissionDenied from rest_framework.validators import UniqueValidator -from organisations.invites.models import Invite +from organisations.invites.models import Invite, InviteLink from users.auth_type import AuthType from users.constants import DEFAULT_DELETE_ORPHAN_ORGANISATIONS_VALUE from users.models import FFAdminUser, SignUpType @@ -23,7 +23,28 @@ class Meta: fields = ("key",) -class CustomUserCreateSerializer(UserCreateSerializer): +class InviteLinkValidationMixin: + invite_hash = serializers.CharField(required=False, write_only=True) + + def _validate_registration_invite(self, email: str, sign_up_type: str) -> None: + if settings.ALLOW_REGISTRATION_WITHOUT_INVITE: + return + + valid = False + + match sign_up_type: + case SignUpType.INVITE_LINK.value: + valid = InviteLink.objects.filter( + hash=self.initial_data.get("invite_hash") + ).exists() + case SignUpType.INVITE_EMAIL.value: + valid = Invite.objects.filter(email__iexact=email.lower()).exists() + + if not valid: + raise PermissionDenied(USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE) + + +class CustomUserCreateSerializer(UserCreateSerializer, InviteLinkValidationMixin): key = serializers.SerializerMethodField() class Meta(UserCreateSerializer.Meta): @@ -58,6 +79,10 @@ def validate(self, attrs): self.context.get("request"), email=email, raise_exception=True ) + self._validate_registration_invite( + email=email, sign_up_type=attrs.get("sign_up_type") + ) + attrs["email"] = email.lower() return attrs @@ -66,16 +91,6 @@ def get_key(instance): token, _ = Token.objects.get_or_create(user=instance) return token.key - def save(self, **kwargs): - if not ( - settings.ALLOW_REGISTRATION_WITHOUT_INVITE - or self.validated_data.get("sign_up_type") == SignUpType.INVITE_LINK.value - or Invite.objects.filter(email=self.validated_data.get("email")) - ): - raise PermissionDenied(USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE) - - return super(CustomUserCreateSerializer, self).save(**kwargs) - class CustomUserDelete(serializers.Serializer): current_password = serializers.CharField( diff --git a/api/tests/integration/custom_auth/end_to_end/test_custom_auth_integration.py b/api/tests/integration/custom_auth/end_to_end/test_custom_auth_integration.py index aec4b9b4327f..e274475119b5 100644 --- a/api/tests/integration/custom_auth/end_to_end/test_custom_auth_integration.py +++ b/api/tests/integration/custom_auth/end_to_end/test_custom_auth_integration.py @@ -12,7 +12,7 @@ from organisations.invites.models import Invite from organisations.models import Organisation -from users.models import FFAdminUser +from users.models import FFAdminUser, SignUpType def test_register_and_login_workflows(db: None, api_client: APIClient) -> None: @@ -124,6 +124,7 @@ def test_can_register_with_invite_if_registration_disabled_without_invite( "password": password, "first_name": "test", "last_name": "register", + "sign_up_type": SignUpType.INVITE_EMAIL.value, } Invite.objects.create(email=email, organisation=organisation) diff --git a/api/tests/unit/custom_auth/conftest.py b/api/tests/unit/custom_auth/conftest.py new file mode 100644 index 000000000000..17d5f760c4c1 --- /dev/null +++ b/api/tests/unit/custom_auth/conftest.py @@ -0,0 +1,9 @@ +import pytest + +from organisations.invites.models import InviteLink +from organisations.models import Organisation + + +@pytest.fixture() +def invite_link(organisation: Organisation) -> InviteLink: + return InviteLink.objects.create(organisation=organisation) diff --git a/api/tests/unit/custom_auth/oauth/test_unit_oauth_serializers.py b/api/tests/unit/custom_auth/oauth/test_unit_oauth_serializers.py index bd21e9fc5d08..11a0519e0b6f 100644 --- a/api/tests/unit/custom_auth/oauth/test_unit_oauth_serializers.py +++ b/api/tests/unit/custom_auth/oauth/test_unit_oauth_serializers.py @@ -1,5 +1,7 @@ +from typing import Type from unittest import mock +import pytest from django.test import RequestFactory from django.utils import timezone from pytest_django.fixtures import SettingsWrapper @@ -11,6 +13,7 @@ GoogleLoginSerializer, OAuthLoginSerializer, ) +from organisations.invites.models import InviteLink from users.models import FFAdminUser, SignUpType @@ -128,7 +131,11 @@ def test_OAuthLoginSerializer_calls_is_authentication_method_valid_correctly_if_ def test_OAuthLoginSerializer_allows_registration_if_sign_up_type_is_invite_link( - settings: SettingsWrapper, rf: RequestFactory, mocker: MockerFixture, db: None + settings: SettingsWrapper, + rf: RequestFactory, + mocker: MockerFixture, + db: None, + invite_link: InviteLink, ): # Given settings.ALLOW_REGISTRATION_WITHOUT_INVITE = False @@ -140,6 +147,7 @@ def test_OAuthLoginSerializer_allows_registration_if_sign_up_type_is_invite_link data={ "access_token": "some_token", "sign_up_type": SignUpType.INVITE_LINK.value, + "invite_hash": invite_link.hash, }, context={"request": request}, ) @@ -153,3 +161,38 @@ def test_OAuthLoginSerializer_allows_registration_if_sign_up_type_is_invite_link # Then assert user + + +@pytest.mark.parametrize( + "serializer_class", (GithubLoginSerializer, GithubLoginSerializer) +) +def test_OAuthLoginSerializer_allows_login_if_allow_registration_without_invite_is_false( + settings: SettingsWrapper, + rf: RequestFactory, + mocker: MockerFixture, + admin_user: FFAdminUser, + serializer_class: Type[OAuthLoginSerializer], +): + # Given + settings.ALLOW_REGISTRATION_WITHOUT_INVITE = False + + request = rf.post("/api/v1/auth/users/") + + serializer = serializer_class( + data={"access_token": "some_token"}, + context={"request": request}, + ) + # monkey patch the get_user_info method to return the mock user data + serializer.get_user_info = lambda: { + "email": admin_user.email, + "github_user_id": "abc123", + "google_user_id": "abc123", + } + + serializer.is_valid(raise_exception=True) + + # When + user = serializer.save() + + # Then + assert user diff --git a/api/tests/unit/custom_auth/oauth/test_unit_oauth_views.py b/api/tests/unit/custom_auth/oauth/test_unit_oauth_views.py index 0f742267b71b..99a451bab4eb 100644 --- a/api/tests/unit/custom_auth/oauth/test_unit_oauth_views.py +++ b/api/tests/unit/custom_auth/oauth/test_unit_oauth_views.py @@ -9,6 +9,7 @@ from organisations.invites.models import Invite from organisations.models import Organisation +from users.models import SignUpType @mock.patch("custom_auth.oauth.serializers.get_user_info") @@ -66,7 +67,13 @@ def test_can_register_with_google_with_invite_if_registration_disabled( Invite.objects.create(organisation=organisation, email=email) # When - response = client.post(url, data={"access_token": "some-token"}) + response = client.post( + url, + data={ + "access_token": "some-token", + "sign_up_type": SignUpType.INVITE_EMAIL.value, + }, + ) # Then assert response.status_code == status.HTTP_200_OK @@ -89,7 +96,13 @@ def test_can_register_with_github_with_invite_if_registration_disabled( Invite.objects.create(organisation=organisation, email=email) # When - response = client.post(url, data={"access_token": "some-token"}) + response = client.post( + url, + data={ + "access_token": "some-token", + "sign_up_type": SignUpType.INVITE_EMAIL.value, + }, + ) # Then assert response.status_code == status.HTTP_200_OK diff --git a/api/tests/unit/custom_auth/test_unit_custom_auth_serializer.py b/api/tests/unit/custom_auth/test_unit_custom_auth_serializer.py index 00f099e1ace6..010a861f30ab 100644 --- a/api/tests/unit/custom_auth/test_unit_custom_auth_serializer.py +++ b/api/tests/unit/custom_auth/test_unit_custom_auth_serializer.py @@ -1,7 +1,13 @@ +import pytest from django.test import RequestFactory from pytest_django.fixtures import SettingsWrapper +from rest_framework.exceptions import PermissionDenied +from custom_auth.constants import ( + USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE, +) from custom_auth.serializers import CustomUserCreateSerializer +from organisations.invites.models import InviteLink from users.models import FFAdminUser, SignUpType user_dict = { @@ -70,6 +76,7 @@ def test_CustomUserCreateSerializer_calls_is_authentication_method_valid_correct def test_CustomUserCreateSerializer_allows_registration_if_sign_up_type_is_invite_link( + invite_link: InviteLink, db: None, settings: SettingsWrapper, rf: RequestFactory, @@ -80,6 +87,7 @@ def test_CustomUserCreateSerializer_allows_registration_if_sign_up_type_is_invit data = { **user_dict, "sign_up_type": SignUpType.INVITE_LINK.value, + "invite_hash": invite_link.hash, } serializer = CustomUserCreateSerializer( @@ -92,3 +100,48 @@ def test_CustomUserCreateSerializer_allows_registration_if_sign_up_type_is_invit # Then assert user + + +def test_invite_link_validation_mixin_validate_fails_if_invite_link_hash_not_provided( + settings: SettingsWrapper, + db: None, +) -> None: + # Given + settings.ALLOW_REGISTRATION_WITHOUT_INVITE = False + + serializer = CustomUserCreateSerializer( + data={ + **user_dict, + "sign_up_type": SignUpType.INVITE_LINK.value, + } + ) + + # When + with pytest.raises(PermissionDenied) as exc_info: + serializer.is_valid(raise_exception=True) + + # Then + assert exc_info.value.detail == USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE + + +def test_invite_link_validation_mixin_validate_fails_if_invite_link_hash_not_valid( + invite_link: InviteLink, + settings: SettingsWrapper, +) -> None: + # Given + settings.ALLOW_REGISTRATION_WITHOUT_INVITE = False + + serializer = CustomUserCreateSerializer( + data={ + **user_dict, + "sign_up_type": SignUpType.INVITE_LINK.value, + "invite_hash": "invalid-hash", + } + ) + + # When + with pytest.raises(PermissionDenied) as exc_info: + serializer.is_valid(raise_exception=True) + + # Then + assert exc_info.value.detail == USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE