diff --git a/api/custom_auth/oauth/serializers.py b/api/custom_auth/oauth/serializers.py index 63e758b127bf..6a1e80ab90af 100644 --- a/api/custom_auth/oauth/serializers.py +++ b/api/custom_auth/oauth/serializers.py @@ -1,11 +1,15 @@ +from abc import abstractmethod + from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.auth.signals import user_logged_in +from django.db.models import F from rest_framework import serializers from rest_framework.authtoken.models import Token from rest_framework.exceptions import PermissionDenied from organisations.invites.models import Invite +from users.auth_type import AuthType from users.models import SignUpType from ..constants import USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE @@ -30,6 +34,9 @@ class OAuthLoginSerializer(serializers.Serializer): write_only=True, ) + auth_type: AuthType | None = None + user_model_id_attribute: str = "id" + class Meta: abstract = True @@ -53,8 +60,28 @@ def create(self, validated_data): return Token.objects.get_or_create(user=user)[0] def _get_user(self, user_data: dict): - email = user_data.get("email") - existing_user = UserModel.objects.filter(email=email).first() + email: str = user_data.pop("email") + + # There are a number of scenarios that we're catering for in this + # query: + # 1. A new user arriving, and immediately authenticating with + # the given social auth method. + # 2. A user that has previously authenticated with method A is now + # authenticating with method B. Using the `email__iexact` means + # that we'll always retrieve the user that already authenticated + # with A. + # 3. A user that (prior to the case sensitivity fix) authenticated + # with multiple methods and ended up with duplicate user accounts. + # Since it's difficult for us to know which user account they are + # using as their primary, we order by the method they are currently + # authenticating with and grab the first one in the list. + existing_user = ( + UserModel.objects.filter(email__iexact=email) + .order_by( + F(self.user_model_id_attribute).desc(nulls_last=True), + ) + .first() + ) if not existing_user: sign_up_type = self.validated_data.get("sign_up_type") @@ -65,20 +92,46 @@ def _get_user(self, user_data: dict): ): raise PermissionDenied(USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE) - return UserModel.objects.create(**user_data, sign_up_type=sign_up_type) + return UserModel.objects.create( + **user_data, email=email.lower(), sign_up_type=sign_up_type + ) + elif existing_user.auth_type != self.get_auth_type().value: + # In this scenario, we're seeing a user that had previously + # authenticated with another authentication method and is now + # authenticating with a new OAuth provider. + setattr( + existing_user, + self.user_model_id_attribute, + user_data[self.user_model_id_attribute], + ) + existing_user.save() return existing_user + @abstractmethod def get_user_info(self): raise NotImplementedError("`get_user_info()` must be implemented.") + def get_auth_type(self) -> AuthType: + if not self.auth_type: # pragma: no cover + raise NotImplementedError( + "`auth_type` must be set, or `get_auth_type()` must be implemented." + ) + return self.auth_type + class GoogleLoginSerializer(OAuthLoginSerializer): + auth_type = AuthType.GOOGLE + user_model_id_attribute = "google_user_id" + def get_user_info(self): return get_user_info(self.validated_data["access_token"]) class GithubLoginSerializer(OAuthLoginSerializer): + auth_type = AuthType.GITHUB + user_model_id_attribute = "github_user_id" + def get_user_info(self): github_user = GithubUser(code=self.validated_data["access_token"]) return github_user.get_user_info() diff --git a/api/tests/unit/custom_auth/oauth/test_unit_oauth_views.py b/api/tests/unit/custom_auth/oauth/test_unit_oauth_views.py index 2a8895de490c..0f742267b71b 100644 --- a/api/tests/unit/custom_auth/oauth/test_unit_oauth_views.py +++ b/api/tests/unit/custom_auth/oauth/test_unit_oauth_views.py @@ -1,7 +1,9 @@ from unittest import mock +from django.db.models import Model from django.test import override_settings from django.urls import reverse +from pytest_mock import MockerFixture from rest_framework import status from rest_framework.test import APIClient @@ -103,7 +105,12 @@ def test_can_login_with_google_if_registration_disabled( client = APIClient() email = "test@example.com" - mock_get_user_info.return_value = {"email": email} + mock_get_user_info.return_value = { + "email": email, + "first_name": "John", + "last_name": "Smith", + "google_user_id": "abc123", + } django_user_model.objects.create(email=email) # When @@ -126,7 +133,12 @@ def test_can_login_with_github_if_registration_disabled( email = "test@example.com" mock_github_user = mock.MagicMock() MockGithubUser.return_value = mock_github_user - mock_github_user.get_user_info.return_value = {"email": email} + mock_github_user.get_user_info.return_value = { + "email": email, + "first_name": "John", + "last_name": "Smith", + "github_user_id": "abc123", + } django_user_model.objects.create(email=email) # When @@ -135,3 +147,144 @@ def test_can_login_with_github_if_registration_disabled( # Then assert response.status_code == status.HTTP_200_OK assert "key" in response.json() + + +def test_login_with_google_updates_existing_user_case_insensitive( + db: None, + django_user_model: type[Model], + mocker: MockerFixture, + api_client: APIClient, +) -> None: + # Given + email_lower = "test@example.com" + email_upper = email_lower.upper() + google_user_id = "abc123" + + django_user_model.objects.create(email=email_lower) + + mocker.patch( + "custom_auth.oauth.serializers.get_user_info", + return_value={ + "email": email_upper, + "first_name": "John", + "last_name": "Smith", + "google_user_id": google_user_id, + }, + ) + + url = reverse("api-v1:custom_auth:oauth:google-oauth-login") + + # When + response = api_client.post(url, data={"access_token": "some-token"}) + + # Then + assert response.status_code == status.HTTP_200_OK + + qs = django_user_model.objects.filter(email__iexact=email_lower) + assert qs.count() == 1 + + user = qs.first() + assert user.email == email_lower + assert user.google_user_id == google_user_id + + +def test_login_with_github_updates_existing_user_case_insensitive( + db: None, + django_user_model: type[Model], + mocker: MockerFixture, + api_client: APIClient, +) -> None: + # Given + email_lower = "test@example.com" + email_upper = email_lower.upper() + github_user_id = "abc123" + + django_user_model.objects.create(email=email_lower) + + mock_github_user = mock.MagicMock() + mocker.patch( + "custom_auth.oauth.serializers.GithubUser", return_value=mock_github_user + ) + mock_github_user.get_user_info.return_value = { + "email": email_upper, + "first_name": "John", + "last_name": "Smith", + "github_user_id": github_user_id, + } + + url = reverse("api-v1:custom_auth:oauth:github-oauth-login") + + # When + response = api_client.post(url, data={"access_token": "some-token"}) + + # Then + assert response.status_code == status.HTTP_200_OK + + qs = django_user_model.objects.filter(email__iexact=email_lower) + assert qs.count() == 1 + + user = qs.first() + assert user.email == email_lower + assert user.github_user_id == github_user_id + + +def test_user_with_duplicate_accounts_authenticates_as_the_correct_oauth_user( + db: None, + django_user_model: type[Model], + api_client: APIClient, + mocker: MockerFixture, +) -> None: + """ + Specific test to verify the correct behaviour for users affected by + https://github.com/Flagsmith/flagsmith/issues/4185. + """ + + # Given + email_lower = "test@example.com" + email_upper = email_lower.upper() + + github_user = django_user_model.objects.create( + email=email_lower, github_user_id="abc123" + ) + google_user = django_user_model.objects.create( + email=email_upper, google_user_id="abc123" + ) + + mock_github_user = mock.MagicMock() + mocker.patch( + "custom_auth.oauth.serializers.GithubUser", return_value=mock_github_user + ) + mock_github_user.get_user_info.return_value = { + "email": email_lower, + "first_name": "John", + "last_name": "Smith", + "github_user_id": github_user.github_user_id, + } + + mocker.patch( + "custom_auth.oauth.serializers.get_user_info", + return_value={ + "email": email_upper, + "first_name": "John", + "last_name": "Smith", + "google_user_id": google_user.google_user_id, + }, + ) + + github_auth_url = reverse("api-v1:custom_auth:oauth:github-oauth-login") + google_auth_url = reverse("api-v1:custom_auth:oauth:google-oauth-login") + + # When + auth_with_github_response = api_client.post( + github_auth_url, data={"access_token": "some-token"} + ) + auth_with_google_response = api_client.post( + google_auth_url, data={"access_token": "some-token"} + ) + + # Then + github_auth_key = auth_with_github_response.json().get("key") + assert github_auth_key == github_user.auth_token.key + + google_auth_key = auth_with_google_response.json().get("key") + assert google_auth_key == google_user.auth_token.key