From b0ef0134cf40703de225ffa3ad4363fee4f8f997 Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Wed, 18 Oct 2023 16:09:47 +0530 Subject: [PATCH] fix: rate limit admin endpoints (#2703) --- api/api_keys/user.py | 4 +++ api/app/settings/common.py | 3 +++ api/app/settings/test.py | 1 + .../identities/tests/test_views.py | 18 +++++++++++++ .../identities/traits/tests/test_views.py | 25 +++++++++++++++++++ api/environments/identities/traits/views.py | 1 + api/environments/identities/views.py | 1 + api/environments/models.py | 2 +- api/environments/sdk/views.py | 1 + api/features/tests/test_views.py | 17 +++++++++++++ api/features/views.py | 1 + .../environments/test_clone_environment.py | 8 +++--- ...test_environments_views_sdk_environment.py | 19 ++++++++++++++ 13 files changed, 96 insertions(+), 5 deletions(-) diff --git a/api/api_keys/user.py b/api/api_keys/user.py index 2cd32f85c3b9..33d370ffcb52 100644 --- a/api/api_keys/user.py +++ b/api/api_keys/user.py @@ -27,6 +27,10 @@ def __init__(self, key: MasterAPIKey): def is_authenticated(self) -> bool: return True + @property + def pk(self) -> str: + return self.key.id + @property def is_master_api_key_user(self) -> bool: return True diff --git a/api/app/settings/common.py b/api/app/settings/common.py index 63b9247ab372..1b69b80b5b40 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -218,6 +218,7 @@ LOGIN_THROTTLE_RATE = env("LOGIN_THROTTLE_RATE", "20/min") SIGNUP_THROTTLE_RATE = env("SIGNUP_THROTTLE_RATE", "10000/min") +USER_THROTTLE_RATE = env("USER_THROTTLE_RATE", "500/min") REST_FRAMEWORK = { "DEFAULT_PERMISSION_CLASSES": ["rest_framework.permissions.IsAuthenticated"], "DEFAULT_AUTHENTICATION_CLASSES": ( @@ -227,11 +228,13 @@ "PAGE_SIZE": 10, "UNICODE_JSON": False, "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination", + "DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.UserRateThrottle"], "DEFAULT_THROTTLE_RATES": { "login": LOGIN_THROTTLE_RATE, "signup": SIGNUP_THROTTLE_RATE, "mfa_code": "5/min", "invite": "10/min", + "user": USER_THROTTLE_RATE, }, "DEFAULT_FILTER_BACKENDS": ["django_filters.rest_framework.DjangoFilterBackend"], "DEFAULT_RENDERER_CLASSES": [ diff --git a/api/app/settings/test.py b/api/app/settings/test.py index 9823dde8aeea..0bd851a28829 100644 --- a/api/app/settings/test.py +++ b/api/app/settings/test.py @@ -9,4 +9,5 @@ "mfa_code": "5/min", "invite": "10/min", "signup": "100/min", + "user": "100000/day", } diff --git a/api/environments/identities/tests/test_views.py b/api/environments/identities/tests/test_views.py index 2b1f5352d376..6086c707b002 100644 --- a/api/environments/identities/tests/test_views.py +++ b/api/environments/identities/tests/test_views.py @@ -25,6 +25,24 @@ from util.tests import Helper +def test_get_identities_is_not_throttled_by_user_throttle( + environment, feature, identity, api_client, settings +): + # Given + settings.REST_FRAMEWORK = {"DEFAULT_THROTTLE_RATES": {"user": "1/minute"}} + + api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key) + base_url = reverse("api-v1:sdk-identities") + url = f"{base_url}?identifier={identity.identifier}" + + # When + for _ in range(10): + response = api_client.get(url) + + # Then + assert response.status_code == status.HTTP_200_OK + + @pytest.mark.django_db class IdentityTestCase(TestCase): identifier = "user1" diff --git a/api/environments/identities/traits/tests/test_views.py b/api/environments/identities/traits/tests/test_views.py index d9a97f0e42f0..1f7f68c89c3d 100644 --- a/api/environments/identities/traits/tests/test_views.py +++ b/api/environments/identities/traits/tests/test_views.py @@ -775,3 +775,28 @@ def test_delete_trait_only_deletes_traits_in_current_environment(self): # and assert Trait.objects.filter(pk=trait_2.id).exists() + + +def test_set_trait_for_an_identity_is_not_throttled_by_user_throttle( + settings, identity, environment, api_client +): + # Given + settings.REST_FRAMEWORK = {"DEFAULT_THROTTLE_RATES": {"user": "1/minute"}} + + api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key) + + url = reverse("api-v1:sdk-traits-list") + data = { + "identity": {"identifier": identity.identifier}, + "trait_key": "key", + "trait_value": "value", + } + + # When + for _ in range(10): + res = api_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert res.status_code == status.HTTP_200_OK diff --git a/api/environments/identities/traits/views.py b/api/environments/identities/traits/views.py index 6a329ecb1000..d96932cebdd0 100644 --- a/api/environments/identities/traits/views.py +++ b/api/environments/identities/traits/views.py @@ -181,6 +181,7 @@ def post(self, request, identifier, trait_key, *args, **kwargs): class SDKTraits(mixins.CreateModelMixin, viewsets.GenericViewSet): permission_classes = (EnvironmentKeyPermissions, TraitPersistencePermissions) authentication_classes = (EnvironmentKeyAuthentication,) + throttle_classes = [] def get_serializer_class(self): if self.action == "increment_value": diff --git a/api/environments/identities/views.py b/api/environments/identities/views.py index 77a891361820..dec1b52dd672 100644 --- a/api/environments/identities/views.py +++ b/api/environments/identities/views.py @@ -152,6 +152,7 @@ def get(self, request, identifier, *args, **kwargs): class SDKIdentities(SDKAPIView): serializer_class = IdentifyWithTraitsSerializer pagination_class = None # set here to ensure documentation is correct + throttle_classes = [] @swagger_auto_schema( responses={200: SDKIdentitiesResponseSerializer()}, diff --git a/api/environments/models.py b/api/environments/models.py index 112bbdb7683b..fb714401443b 100644 --- a/api/environments/models.py +++ b/api/environments/models.py @@ -163,7 +163,7 @@ def clone(self, name: str, api_key: str = None) -> "Environment": clone.api_key = api_key if api_key else create_hash() clone.save() - # Since identities are closely tied to the enviroment + # Since identities are closely tied to the environment # it does not make much sense to clone them, hence # only clone feature states without identities for feature_state in self.feature_states.filter(identity=None): diff --git a/api/environments/sdk/views.py b/api/environments/sdk/views.py index b7ac0acb721f..6d7df10f1667 100644 --- a/api/environments/sdk/views.py +++ b/api/environments/sdk/views.py @@ -10,6 +10,7 @@ class SDKEnvironmentAPIView(APIView): permission_classes = (EnvironmentKeyPermissions,) + throttle_classes = [] def get_authenticators(self): return [EnvironmentKeyAuthentication(required_key_prefix="ser.")] diff --git a/api/features/tests/test_views.py b/api/features/tests/test_views.py index 062104ea32e2..826d07ad3314 100644 --- a/api/features/tests/test_views.py +++ b/api/features/tests/test_views.py @@ -816,3 +816,20 @@ def test_create_segment_override(admin_client, feature, segment, environment): assert created_override is not None assert created_override.enabled is enabled assert created_override.get_feature_state_value() == string_value + + +def test_get_flags_is_not_throttled_by_user_throttle( + api_client, environment, feature, settings +): + # Given + settings.REST_FRAMEWORK = {"DEFAULT_THROTTLE_RATES": {"user": "1/minute"}} + api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key) + + url = reverse("api-v1:flags") + + # When + for _ in range(10): + response = api_client.get(url) + + # Then + assert response.status_code == status.HTTP_200_OK diff --git a/api/features/views.py b/api/features/views.py index 9aed9e4fd8e3..4d8831b20e6d 100644 --- a/api/features/views.py +++ b/api/features/views.py @@ -551,6 +551,7 @@ class SDKFeatureStates(GenericAPIView): permission_classes = (EnvironmentKeyPermissions,) authentication_classes = (EnvironmentKeyAuthentication,) renderer_classes = [JSONRenderer] + throttle_classes = [] pagination_class = None @swagger_auto_schema( diff --git a/api/tests/integration/environments/test_clone_environment.py b/api/tests/integration/environments/test_clone_environment.py index 96a1d9a2b3ab..dd83f99aa46a 100644 --- a/api/tests/integration/environments/test_clone_environment.py +++ b/api/tests/integration/environments/test_clone_environment.py @@ -18,7 +18,7 @@ def test_clone_environment_clones_feature_states_with_value( client, project, environment, environment_api_key, feature ): - # Firstly, let's update feature state value of the source enviroment + # Firstly, let's update feature state value of the source environment # fetch the feature state id to update feature_state = get_env_feature_states_list_with_api( client, {"environment": environment, "feature": feature} @@ -52,7 +52,7 @@ def test_clone_environment_clones_feature_states_with_value( client, {"environment": environment} ) - # Now, fetch the feature states of the clone enviroment + # Now, fetch the feature states of the clone environment clone_env_feature_states = get_env_feature_states_list_with_api( client, {"environment": res.json()["id"]} ) @@ -81,13 +81,13 @@ def test_clone_environment_clones_feature_states_with_value( def test_clone_environment_creates_admin_permission_with_the_current_user( admin_user, admin_client, environment, environment_api_key ): - # Firstly, let's create the clone of the enviroment + # Firstly, let's create the clone of the environment env_name = "Cloned env" url = reverse("api-v1:environments:environment-clone", args=[environment_api_key]) res = admin_client.post(url, {"name": env_name}) clone_env_api_key = res.json()["api_key"] - # Now, fetch the permission of the newly creatd enviroment + # Now, fetch the permission of the newly creatd environment perm_url = reverse( "api-v1:environments:environment-user-permissions-list", args=[clone_env_api_key], diff --git a/api/tests/unit/environments/test_environments_views_sdk_environment.py b/api/tests/unit/environments/test_environments_views_sdk_environment.py index 0b1554c45837..9dc3983b1466 100644 --- a/api/tests/unit/environments/test_environments_views_sdk_environment.py +++ b/api/tests/unit/environments/test_environments_views_sdk_environment.py @@ -81,3 +81,22 @@ def test_get_environment_document_fails_with_invalid_key( # We get a 403 since only the server side API keys are able to access the # environment document assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_get_environment_document_is_not_throttled_by_user_throttle( + environment, feature, settings, environment_api_key +): + # Given + settings.REST_FRAMEWORK = {"DEFAULT_THROTTLE_RATES": {"user": "1/minute"}} + + client = APIClient() + client.credentials(HTTP_X_ENVIRONMENT_KEY=environment_api_key.key) + + url = reverse("api-v1:environment-document") + + # When + for _ in range(10): + response = client.get(url) + + # Then + assert response.status_code == status.HTTP_200_OK