Skip to content

Commit

Permalink
fix: oauth user case sensitivity (#4207)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewelwell authored Jun 24, 2024
1 parent 5e87f39 commit af955bf
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 5 deletions.
59 changes: 56 additions & 3 deletions api/custom_auth/oauth/serializers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from abc import abstractmethod

from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.signals import user_logged_in
from django.db.models import F
from rest_framework import serializers
from rest_framework.authtoken.models import Token
from rest_framework.exceptions import PermissionDenied

from organisations.invites.models import Invite
from users.auth_type import AuthType
from users.models import SignUpType

from ..constants import USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE
Expand All @@ -30,6 +34,9 @@ class OAuthLoginSerializer(serializers.Serializer):
write_only=True,
)

auth_type: AuthType | None = None
user_model_id_attribute: str = "id"

class Meta:
abstract = True

Expand All @@ -53,8 +60,28 @@ def create(self, validated_data):
return Token.objects.get_or_create(user=user)[0]

def _get_user(self, user_data: dict):
email = user_data.get("email")
existing_user = UserModel.objects.filter(email=email).first()
email: str = user_data.pop("email")

# There are a number of scenarios that we're catering for in this
# query:
# 1. A new user arriving, and immediately authenticating with
# the given social auth method.
# 2. A user that has previously authenticated with method A is now
# authenticating with method B. Using the `email__iexact` means
# that we'll always retrieve the user that already authenticated
# with A.
# 3. A user that (prior to the case sensitivity fix) authenticated
# with multiple methods and ended up with duplicate user accounts.
# Since it's difficult for us to know which user account they are
# using as their primary, we order by the method they are currently
# authenticating with and grab the first one in the list.
existing_user = (
UserModel.objects.filter(email__iexact=email)
.order_by(
F(self.user_model_id_attribute).desc(nulls_last=True),
)
.first()
)

if not existing_user:
sign_up_type = self.validated_data.get("sign_up_type")
Expand All @@ -65,20 +92,46 @@ def _get_user(self, user_data: dict):
):
raise PermissionDenied(USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE)

return UserModel.objects.create(**user_data, sign_up_type=sign_up_type)
return UserModel.objects.create(
**user_data, email=email.lower(), sign_up_type=sign_up_type
)
elif existing_user.auth_type != self.get_auth_type().value:
# In this scenario, we're seeing a user that had previously
# authenticated with another authentication method and is now
# authenticating with a new OAuth provider.
setattr(
existing_user,
self.user_model_id_attribute,
user_data[self.user_model_id_attribute],
)
existing_user.save()

return existing_user

@abstractmethod
def get_user_info(self):
raise NotImplementedError("`get_user_info()` must be implemented.")

def get_auth_type(self) -> AuthType:
if not self.auth_type: # pragma: no cover
raise NotImplementedError(
"`auth_type` must be set, or `get_auth_type()` must be implemented."
)
return self.auth_type


class GoogleLoginSerializer(OAuthLoginSerializer):
auth_type = AuthType.GOOGLE
user_model_id_attribute = "google_user_id"

def get_user_info(self):
return get_user_info(self.validated_data["access_token"])


class GithubLoginSerializer(OAuthLoginSerializer):
auth_type = AuthType.GITHUB
user_model_id_attribute = "github_user_id"

def get_user_info(self):
github_user = GithubUser(code=self.validated_data["access_token"])
return github_user.get_user_info()
157 changes: 155 additions & 2 deletions api/tests/unit/custom_auth/oauth/test_unit_oauth_views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest import mock

from django.db.models import Model
from django.test import override_settings
from django.urls import reverse
from pytest_mock import MockerFixture
from rest_framework import status
from rest_framework.test import APIClient

Expand Down Expand Up @@ -103,7 +105,12 @@ def test_can_login_with_google_if_registration_disabled(
client = APIClient()

email = "[email protected]"
mock_get_user_info.return_value = {"email": email}
mock_get_user_info.return_value = {
"email": email,
"first_name": "John",
"last_name": "Smith",
"google_user_id": "abc123",
}
django_user_model.objects.create(email=email)

# When
Expand All @@ -126,7 +133,12 @@ def test_can_login_with_github_if_registration_disabled(
email = "[email protected]"
mock_github_user = mock.MagicMock()
MockGithubUser.return_value = mock_github_user
mock_github_user.get_user_info.return_value = {"email": email}
mock_github_user.get_user_info.return_value = {
"email": email,
"first_name": "John",
"last_name": "Smith",
"github_user_id": "abc123",
}
django_user_model.objects.create(email=email)

# When
Expand All @@ -135,3 +147,144 @@ def test_can_login_with_github_if_registration_disabled(
# Then
assert response.status_code == status.HTTP_200_OK
assert "key" in response.json()


def test_login_with_google_updates_existing_user_case_insensitive(
db: None,
django_user_model: type[Model],
mocker: MockerFixture,
api_client: APIClient,
) -> None:
# Given
email_lower = "[email protected]"
email_upper = email_lower.upper()
google_user_id = "abc123"

django_user_model.objects.create(email=email_lower)

mocker.patch(
"custom_auth.oauth.serializers.get_user_info",
return_value={
"email": email_upper,
"first_name": "John",
"last_name": "Smith",
"google_user_id": google_user_id,
},
)

url = reverse("api-v1:custom_auth:oauth:google-oauth-login")

# When
response = api_client.post(url, data={"access_token": "some-token"})

# Then
assert response.status_code == status.HTTP_200_OK

qs = django_user_model.objects.filter(email__iexact=email_lower)
assert qs.count() == 1

user = qs.first()
assert user.email == email_lower
assert user.google_user_id == google_user_id


def test_login_with_github_updates_existing_user_case_insensitive(
db: None,
django_user_model: type[Model],
mocker: MockerFixture,
api_client: APIClient,
) -> None:
# Given
email_lower = "[email protected]"
email_upper = email_lower.upper()
github_user_id = "abc123"

django_user_model.objects.create(email=email_lower)

mock_github_user = mock.MagicMock()
mocker.patch(
"custom_auth.oauth.serializers.GithubUser", return_value=mock_github_user
)
mock_github_user.get_user_info.return_value = {
"email": email_upper,
"first_name": "John",
"last_name": "Smith",
"github_user_id": github_user_id,
}

url = reverse("api-v1:custom_auth:oauth:github-oauth-login")

# When
response = api_client.post(url, data={"access_token": "some-token"})

# Then
assert response.status_code == status.HTTP_200_OK

qs = django_user_model.objects.filter(email__iexact=email_lower)
assert qs.count() == 1

user = qs.first()
assert user.email == email_lower
assert user.github_user_id == github_user_id


def test_user_with_duplicate_accounts_authenticates_as_the_correct_oauth_user(
db: None,
django_user_model: type[Model],
api_client: APIClient,
mocker: MockerFixture,
) -> None:
"""
Specific test to verify the correct behaviour for users affected by
https://github.com/Flagsmith/flagsmith/issues/4185.
"""

# Given
email_lower = "[email protected]"
email_upper = email_lower.upper()

github_user = django_user_model.objects.create(
email=email_lower, github_user_id="abc123"
)
google_user = django_user_model.objects.create(
email=email_upper, google_user_id="abc123"
)

mock_github_user = mock.MagicMock()
mocker.patch(
"custom_auth.oauth.serializers.GithubUser", return_value=mock_github_user
)
mock_github_user.get_user_info.return_value = {
"email": email_lower,
"first_name": "John",
"last_name": "Smith",
"github_user_id": github_user.github_user_id,
}

mocker.patch(
"custom_auth.oauth.serializers.get_user_info",
return_value={
"email": email_upper,
"first_name": "John",
"last_name": "Smith",
"google_user_id": google_user.google_user_id,
},
)

github_auth_url = reverse("api-v1:custom_auth:oauth:github-oauth-login")
google_auth_url = reverse("api-v1:custom_auth:oauth:google-oauth-login")

# When
auth_with_github_response = api_client.post(
github_auth_url, data={"access_token": "some-token"}
)
auth_with_google_response = api_client.post(
google_auth_url, data={"access_token": "some-token"}
)

# Then
github_auth_key = auth_with_github_response.json().get("key")
assert github_auth_key == github_user.auth_token.key

google_auth_key = auth_with_google_response.json().get("key")
assert google_auth_key == google_user.auth_token.key

0 comments on commit af955bf

Please sign in to comment.