diff --git a/api/features/versioning/versioning_service.py b/api/features/versioning/versioning_service.py index 07d1e375bef0..9a2af394ab5f 100644 --- a/api/features/versioning/versioning_service.py +++ b/api/features/versioning/versioning_service.py @@ -46,6 +46,7 @@ def get_environment_flags_list( feature_states = ( FeatureState.objects.select_related( + "environment", "feature", "feature_state_value", "environment_feature_version", diff --git a/api/poetry.lock b/api/poetry.lock index 872a709123a4..90cccc4edf42 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -3206,17 +3206,17 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale [[package]] name = "pytest-django" -version = "4.5.2" +version = "4.7.0" description = "A Django plugin for pytest." optional = false -python-versions = ">=3.5" +python-versions = ">=3.8" files = [ - {file = "pytest-django-4.5.2.tar.gz", hash = "sha256:d9076f759bb7c36939dbdd5ae6633c18edfc2902d1a69fdbefd2426b970ce6c2"}, - {file = "pytest_django-4.5.2-py3-none-any.whl", hash = "sha256:c60834861933773109334fe5a53e83d1ef4828f2203a1d6a0fa9972f4f75ab3e"}, + {file = "pytest-django-4.7.0.tar.gz", hash = "sha256:92d6fd46b1d79b54fb6b060bbb39428073396cec717d5f2e122a990d4b6aa5e8"}, + {file = "pytest_django-4.7.0-py3-none-any.whl", hash = "sha256:4e1c79d5261ade2dd58d91208017cd8f62cb4710b56e012ecd361d15d5d662a2"}, ] [package.dependencies] -pytest = ">=5.4.0" +pytest = ">=7.0.0" [package.extras] docs = ["sphinx", "sphinx-rtd-theme"] @@ -4372,4 +4372,4 @@ requests = ">=2.7,<3.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "5bb4ae1878deaa542af80e69cb9c735f12fd24bc56379176e0b19715004c0c0f" +content-hash = "dc66f4669e76325b8135718c55c95e68bb11805e181c80961cea58af5fc9b2bb" diff --git a/api/pyproject.toml b/api/pyproject.toml index df302d17dcaa..5c6727f2e478 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -135,7 +135,7 @@ pylint = "~2.16.2" pep8 = "~1.7.1" autopep8 = "~2.0.1" pytest = "~7.2.1" -pytest-django = "~4.5.2" +pytest-django = "^4.5.2" black = "~23.7.0" pip-tools = "~6.13.0" pytest-cov = "~4.1.0" diff --git a/api/tests/unit/features/test_unit_features_views.py b/api/tests/unit/features/test_unit_features_views.py index 09f0735d4933..3d3b506c8bfc 100644 --- a/api/tests/unit/features/test_unit_features_views.py +++ b/api/tests/unit/features/test_unit_features_views.py @@ -11,6 +11,7 @@ from django.forms import model_to_dict from django.urls import reverse from django.utils import timezone +from pytest_django import DjangoAssertNumQueries from pytest_lazyfixture import lazy_fixture from rest_framework import status from rest_framework.test import APIClient, APITestCase @@ -915,14 +916,26 @@ def test_get_flags_is_not_throttled_by_user_throttle( def test_list_feature_states_from_simple_view_set( - environment, feature, admin_user, admin_client, django_assert_num_queries -): + environment: Environment, + feature: Feature, + admin_user: FFAdminUser, + admin_client: APIClient, + django_assert_num_queries: DjangoAssertNumQueries, +) -> None: # Given base_url = reverse("api-v1:features:featurestates-list") url = f"{base_url}?environment={environment.id}" # add another feature - Feature.objects.create(name="another_feature", project=environment.project) + feature2 = Feature.objects.create( + name="another_feature", project=environment.project + ) + + # and a new version for the same feature to check for N+1 issues + v1_feature_state = FeatureState.objects.get( + environment=environment, feature=feature2 + ) + v1_feature_state.clone(env=environment, version=2, live_from=timezone.now()) # add another organisation with a project, environment and feature (which should be # excluded) @@ -980,7 +993,16 @@ def test_list_feature_states_nested_environment_view_set( ) # Add another feature - Feature.objects.create(name="another_feature", project=project) + second_feature = Feature.objects.create(name="another_feature", project=project) + + # create some new versions to test N+1 issues + v1_feature_state = FeatureState.objects.get( + feature=second_feature, environment=environment + ) + v2_feature_state = v1_feature_state.clone( + env=environment, version=2, live_from=timezone.now() + ) + v2_feature_state.clone(env=environment, version=3, live_from=timezone.now()) # When with django_assert_num_queries(8): @@ -2315,3 +2337,32 @@ def test_update_segment_override__using_simple_feature_state_viewset__denies_upd # Then assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_list_features_n_plus_1( + staff_client: APIClient, + project: Project, + feature: Feature, + with_project_permissions: Callable, + django_assert_num_queries: DjangoAssertNumQueries, + environment: Environment, +) -> None: + # Given + with_project_permissions([VIEW_PROJECT]) + + base_url = reverse("api-v1:projects:project-features-list", args=[project.id]) + url = f"{base_url}?environment={environment.id}" + + # add some more versions to test for N+1 issues + v1_feature_state = FeatureState.objects.get( + feature=feature, environment=environment + ) + for i in range(2, 4): + v1_feature_state.clone(env=environment, version=i, live_from=timezone.now()) + + # When + with django_assert_num_queries(13): + response = staff_client.get(url) + + # Then + assert response.status_code == status.HTTP_200_OK diff --git a/api/tests/unit/users/test_unit_users_views.py b/api/tests/unit/users/test_unit_users_views.py index acd9d73ccc3a..449e4d7f59cf 100644 --- a/api/tests/unit/users/test_unit_users_views.py +++ b/api/tests/unit/users/test_unit_users_views.py @@ -1,6 +1,5 @@ import json import typing -from unittest import TestCase import pytest from dateutil.relativedelta import relativedelta @@ -9,6 +8,7 @@ from django.contrib.auth.models import AbstractUser from django.contrib.auth.tokens import default_token_generator from django.core import mail +from django.test import TestCase from django.urls import reverse from django.utils import timezone from djoser import utils @@ -174,8 +174,16 @@ def test_admin_can_get_users_in_organisation(self): "api-v1:organisations:organisation-users-list", args=[self.organisation.pk] ) + # add some more users to test for N+1 issues + for i in range(2): + additional_user = FFAdminUser.objects.create( + email=f"additional_user_{i}@org.com" + ) + additional_user.add_organisation(self.organisation) + # When - res = self.client.get(url) + with self.assertNumQueries(5): + res = self.client.get(url) # Then assert res.status_code == status.HTTP_200_OK diff --git a/api/users/models.py b/api/users/models.py index b95b8d8cc6be..7c8f48915dfc 100644 --- a/api/users/models.py +++ b/api/users/models.py @@ -235,12 +235,24 @@ def get_organisation_join_date(self, organisation): def get_user_organisation( self, organisation: typing.Union["Organisation", int] ) -> UserOrganisation: + organisation_id = getattr(organisation, "id", organisation) + try: - return self.userorganisation_set.get(organisation=organisation) - except UserOrganisation.DoesNotExist: + # Since the user list view relies on this data, we prefetch it in + # the queryset, hence we can't use `userorganisation_set.get()` + # and instead use this next(filter()) approach. Since most users + # won't have more than ~1 organisation, we can accept the performance + # hit in the case that we are only getting the organisation for a + # single user. + return next( + filter( + lambda uo: uo.organisation_id == organisation_id, + self.userorganisation_set.all(), + ) + ) + except StopIteration: logger.warning( - "User %d is not part of organisation %d" - % (self.id, getattr(organisation, "id", organisation)) + "User %d is not part of organisation %d" % (self.id, organisation_id) ) def get_permitted_projects( diff --git a/api/users/views.py b/api/users/views.py index 39985f75da8b..5e6cd85890c3 100644 --- a/api/users/views.py +++ b/api/users/views.py @@ -2,7 +2,7 @@ from core.helpers import get_current_site_url from django.contrib.auth.mixins import PermissionRequiredMixin -from django.db.models import Q, QuerySet +from django.db.models import Prefetch, Q, QuerySet from django.http import ( Http404, HttpRequest, @@ -21,7 +21,7 @@ from rest_framework.request import Request from rest_framework.response import Response -from organisations.models import Organisation +from organisations.models import Organisation, UserOrganisation from organisations.permissions.permissions import ( MANAGE_USER_GROUPS, NestedIsOrganisationAdminPermission, @@ -104,9 +104,12 @@ class FFAdminUserViewSet(mixins.ListModelMixin, viewsets.GenericViewSet): def get_queryset(self): if self.kwargs.get("organisation_pk"): - queryset = FFAdminUser.objects.filter( - organisations__id=self.kwargs.get("organisation_pk") - ) + queryset = FFAdminUser.objects.prefetch_related( + Prefetch( + "userorganisation_set", + queryset=UserOrganisation.objects.select_related("organisation"), + ) + ).filter(organisations__id=self.kwargs.get("organisation_pk")) queryset = self._apply_query_filters(queryset) return queryset else: