diff --git a/api/app/settings/common.py b/api/app/settings/common.py index 17e4b586ab61..23ae93804e5b 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -630,6 +630,32 @@ CACHE_ENVIRONMENT_DOCUMENT_SECONDS = env.int("CACHE_ENVIRONMENT_DOCUMENT_SECONDS", 0) ENVIRONMENT_DOCUMENT_CACHE_LOCATION = "environment-documents" +USER_THROTTLE_CACHE_NAME = "user-throttle" +USER_THROTTLE_CACHE_BACKEND = env.str( + "USER_THROTTLE_CACHE_BACKEND", "django.core.cache.backends.locmem.LocMemCache" +) +USER_THROTTLE_CACHE_LOCATION = env.str("USER_THROTTLE_CACHE_LOCATION", "admin-throttle") + +# Using Redis for cache +# To use Redis for caching, set the cache backend to `django_redis.cache.RedisCache`. +# and set the cache location to the redis url +# ref: https://github.com/jazzband/django-redis/tree/5.4.0#configure-as-cache-backend + +# Set this to `core.redis_cluster.ClusterConnectionFactory` when using Redis Cluster. +DJANGO_REDIS_CONNECTION_FACTORY = env.str("DJANGO_REDIS_CONNECTION_FACTORY", "") + +# Avoid raising exceptions if redis is down +# ref: https://github.com/jazzband/django-redis/tree/5.4.0#memcached-exceptions-behavior +DJANGO_REDIS_IGNORE_EXCEPTIONS = env.bool( + "DJANGO_REDIS_IGNORE_EXCEPTIONS", default=True +) + +# Log exceptions generated by django-redis +# ref:https://github.com/jazzband/django-redis/tree/5.4.0#log-ignored-exceptions +DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = env.bool( + "DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS", True +) + CACHES = { "default": { "BACKEND": "django.core.cache.backends.locmem.LocMemCache", @@ -676,6 +702,10 @@ "LOCATION": ENVIRONMENT_SEGMENTS_CACHE_LOCATION, "TIMEOUT": ENVIRONMENT_SEGMENTS_CACHE_SECONDS, }, + USER_THROTTLE_CACHE_NAME: { + "BACKEND": USER_THROTTLE_CACHE_BACKEND, + "LOCATION": USER_THROTTLE_CACHE_LOCATION, + }, } TRENCH_AUTH = { diff --git a/api/app/settings/test.py b/api/app/settings/test.py index 1241f1db9d18..6c721d332e98 100644 --- a/api/app/settings/test.py +++ b/api/app/settings/test.py @@ -4,6 +4,7 @@ # We dont want to track tests ENABLE_TELEMETRY = False MAX_PROJECTS_IN_FREE_PLAN = 10 +REST_FRAMEWORK["DEFAULT_THROTTLE_CLASSES"] = ["core.throttling.UserRateThrottle"] REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"] = { "login": "100/min", "mfa_code": "5/min", diff --git a/api/app_analytics/views.py b/api/app_analytics/views.py index 5ce7246213b4..1831334f1d80 100644 --- a/api/app_analytics/views.py +++ b/api/app_analytics/views.py @@ -39,6 +39,7 @@ class SDKAnalyticsFlags(GenericAPIView): permission_classes = (EnvironmentKeyPermissions,) authentication_classes = (EnvironmentKeyAuthentication,) + throttle_classes = [] def get_serializer_class(self): if getattr(self, "swagger_fake_view", False): @@ -116,6 +117,7 @@ class SelfHostedTelemetryAPIView(CreateAPIView): permission_classes = () authentication_classes = () + throttle_classes = [] serializer_class = TelemetrySerializer diff --git a/api/conftest.py b/api/conftest.py index 7954a01bc0b4..fe9ed05b92ac 100644 --- a/api/conftest.py +++ b/api/conftest.py @@ -4,7 +4,7 @@ import boto3 import pytest from django.contrib.contenttypes.models import ContentType -from django.core.cache import cache +from django.core.cache import caches from flag_engine.segments.constants import EQUAL from moto import mock_dynamodb from mypy_boto3_dynamodb.service_resource import DynamoDBServiceResource, Table @@ -350,9 +350,15 @@ def reset_cache(): # https://groups.google.com/g/django-developers/c/zlaPsP13dUY # TL;DR: Use this if your test interacts with cache since django # does not clear cache after every test - cache.clear() + # Clear all caches before the test + for cache in caches.all(): + cache.clear() + yield - cache.clear() + + # Clear all caches after the test + for cache in caches.all(): + cache.clear() @pytest.fixture() diff --git a/api/core/redis_cluster.py b/api/core/redis_cluster.py new file mode 100644 index 000000000000..985760b7d078 --- /dev/null +++ b/api/core/redis_cluster.py @@ -0,0 +1,73 @@ +""" +Temporary module that adds support for Redis Cluster to django-redis by implementing +a connection factory class(`ClusterConnectionFactory`). +This module should be removed once [this](https://github.com/jazzband/django-redis/issues/606) +is resolved. + +Usage: +------ +Include the following configuration in Django project's settings.py file: + +```python +# settings.py + +DJANGO_REDIS_CONNECTION_FACTORY = "core.redis_cluster.ClusterConnectionFactory" +""" + +import threading +from copy import deepcopy + +from django.core.exceptions import ImproperlyConfigured +from django_redis.pool import ConnectionFactory +from redis.cluster import RedisCluster + + +class ClusterConnectionFactory(ConnectionFactory): + """A connection factory for redis.cluster.RedisCluster + The cluster client manages connection pools internally, so we don't want to + do it at this level like the base ConnectionFactory does. + """ + + # A global cache of URL->client so that within a process, we will reuse a + # single client, and therefore a single set of connection pools. + _clients = {} + _clients_lock = threading.Lock() + + def connect(self, url: str) -> RedisCluster: + """Given a connection url, return a client instance. + Prefer to return from our cache but if we don't yet have one build it + to populate the cache. + """ + if url not in self._clients: + with self._clients_lock: + if url not in self._clients: + params = self.make_connection_params(url) + self._clients[url] = self.get_connection(params) + + return self._clients[url] + + def get_connection(self, connection_params: dict) -> RedisCluster: + """ + Given connection_params, return a new client instance. + Basic django-redis ConnectionFactory manages a cache of connection + pools and builds a fresh client each time. because the cluster client + manages its own connection pools, we will instead merge the + "connection" and "client" kwargs and throw them all at the client to + sort out. + If we find conflicting client and connection kwargs, we'll raise an + error. + """ + client_cls_kwargs = deepcopy(self.redis_client_cls_kwargs) + # ... and smash 'em together (crashing if there's conflicts)... + for key, value in connection_params.items(): + if key in client_cls_kwargs: + raise ImproperlyConfigured( + f"Found '{key}' in both the connection and the client kwargs" + ) + client_cls_kwargs[key] = value + + # ... and then build and return the client + return RedisCluster(**client_cls_kwargs) + + def disconnect(self, connection: RedisCluster): + connection.disconnect_connection_pools() diff --git a/api/core/throttling.py b/api/core/throttling.py new file mode 100644 index 000000000000..a69c5260b8a9 --- /dev/null +++ b/api/core/throttling.py @@ -0,0 +1,7 @@ +from django.conf import settings +from django.core.cache import caches +from rest_framework import throttling + + +class UserRateThrottle(throttling.UserRateThrottle): + cache = caches[settings.USER_THROTTLE_CACHE_NAME] diff --git a/api/environments/identities/traits/views.py b/api/environments/identities/traits/views.py index d96932cebdd0..201512643e65 100644 --- a/api/environments/identities/traits/views.py +++ b/api/environments/identities/traits/views.py @@ -113,6 +113,7 @@ class SDKTraitsDeprecated(SDKAPIView): # API to handle /api/v1/identities//traits/ endpoints # if Identity or Trait does not exist it will create one, otherwise will fetch existing serializer_class = TraitSerializerBasic + throttle_classes = [] schema = None diff --git a/api/environments/identities/views.py b/api/environments/identities/views.py index af5750011be1..491031247d20 100644 --- a/api/environments/identities/views.py +++ b/api/environments/identities/views.py @@ -105,6 +105,7 @@ class SDKIdentitiesDeprecated(SDKAPIView): # if Identity does not exist it will create one, otherwise will fetch existing serializer_class = IdentifyWithTraitsSerializer + throttle_classes = [] schema = None diff --git a/api/features/views.py b/api/features/views.py index 75ca0b18df3f..bad7082e6934 100644 --- a/api/features/views.py +++ b/api/features/views.py @@ -602,8 +602,8 @@ class SDKFeatureStates(GenericAPIView): permission_classes = (EnvironmentKeyPermissions,) authentication_classes = (EnvironmentKeyAuthentication,) renderer_classes = [JSONRenderer] - throttle_classes = [] pagination_class = None + throttle_classes = [] @swagger_auto_schema( query_serializer=SDKFeatureStatesQuerySerializer(), diff --git a/api/integrations/sentry/samplers.py b/api/integrations/sentry/samplers.py index 18531ac067aa..f2512af62486 100644 --- a/api/integrations/sentry/samplers.py +++ b/api/integrations/sentry/samplers.py @@ -3,14 +3,14 @@ from django.conf import settings NON_FUNCTIONAL_ENDPOINTS = ("/health", "") -SDK_ENDPOINTS = ( +SDK_ENDPOINTS = { "/api/v1/flags", "/api/v1/identities", "/api/v1/traits", "/api/v1/traits/bulk", "/api/v1/environment-document", "/api/v1/analytics/flags", -) +} def traces_sampler(ctx): diff --git a/api/poetry.lock b/api/poetry.lock index 85b9cde1b7a9..0af1617b2684 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1076,6 +1076,24 @@ django = ">=1.11" ldap3 = ">=2.5,<3" pyasn1 = ">=0.4.6,<0.5" +[[package]] +name = "django-redis" +version = "5.4.0" +description = "Full featured redis cache backend for Django." +optional = false +python-versions = ">=3.6" +files = [ + {file = "django-redis-5.4.0.tar.gz", hash = "sha256:6a02abaa34b0fea8bf9b707d2c363ab6adc7409950b2db93602e6cb292818c42"}, + {file = "django_redis-5.4.0-py3-none-any.whl", hash = "sha256:ebc88df7da810732e2af9987f7f426c96204bf89319df4c6da6ca9a2942edd5b"}, +] + +[package.dependencies] +Django = ">=3.2" +redis = ">=3,<4.0.0 || >4.0.0,<4.0.1 || >4.0.1" + +[package.extras] +hiredis = ["redis[hiredis] (>=3,!=4.0.0,!=4.0.1)"] + [[package]] name = "django-ses" version = "3.5.0" @@ -3471,6 +3489,24 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +[[package]] +name = "redis" +version = "5.0.1" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"}, + {file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""} + +[package.extras] +hiredis = ["hiredis (>=1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "requests" version = "2.31.0" @@ -4436,4 +4472,4 @@ requests = ">=2.7,<3.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "b143e6970298307f2ac4917f4fcf42179ae613fc5c38df947ebc3e3fc129a05f" +content-hash = "726c2c9615317642f03c659b75524f68318d1d7a679a12ce700562a0ff122a33" diff --git a/api/pyproject.toml b/api/pyproject.toml index e827714c67e1..b8a65de963e5 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -103,6 +103,7 @@ pydantic = "~1.10.9" pyngo = "~1.6.0" flagsmith = "^3.4.0" python-gnupg = "^0.5.1" +django-redis = "^5.4.0" [tool.poetry.group.auth-controller] optional = true diff --git a/api/tests/integration/conftest.py b/api/tests/integration/conftest.py index 4d9de2064473..fd1916469077 100644 --- a/api/tests/integration/conftest.py +++ b/api/tests/integration/conftest.py @@ -158,6 +158,18 @@ def sdk_client(environment_api_key): return client +@pytest.fixture() +def server_side_sdk_client( + admin_client: APIClient, environment: int, environment_api_key: str +) -> APIClient: + url = reverse("api-v1:environments:api-keys-list", args={environment_api_key}) + response = admin_client.post(url, data={"name": "Some key"}) + + client = APIClient() + client.credentials(HTTP_X_ENVIRONMENT_KEY=response.json()["key"]) + return client + + @pytest.fixture() def default_feature_value(): return "default_value" diff --git a/api/tests/integration/core/test_user_rate_throttle.py b/api/tests/integration/core/test_user_rate_throttle.py new file mode 100644 index 000000000000..59160fcafb69 --- /dev/null +++ b/api/tests/integration/core/test_user_rate_throttle.py @@ -0,0 +1,153 @@ +import json + +import pytest +from django.urls import reverse +from pytest_lazyfixture import lazy_fixture +from pytest_mock import MockerFixture +from rest_framework import status +from rest_framework.test import APIClient + + +@pytest.mark.parametrize( + "client", + [(lazy_fixture("admin_master_api_key_client")), (lazy_fixture("admin_client"))], +) +def test_user_throttle_can_throttle_admin_endpoints( + client: APIClient, project: int, mocker: MockerFixture, reset_cache: None +) -> None: + # Given + mocker.patch("core.throttling.UserRateThrottle.get_rate", return_value="1/minute") + + url = reverse("api-v1:projects:project-list") + + # Then - first request should be successful + response = client.get(url, content_type="application/json") + assert response.status_code == status.HTTP_200_OK + + # Second request should be throttled + response = client.get(url, content_type="application/json") + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + + +def test_get_flags_is_not_throttled_by_user_throttle( + sdk_client: APIClient, + environment: int, + environment_api_key: str, + mocker: MockerFixture, +) -> None: + # Given + mocker.patch("core.throttling.UserRateThrottle.get_rate", return_value="1/minute") + + url = reverse("api-v1:flags") + + # When + for _ in range(10): + response = sdk_client.get(url) + + # Then + assert response.status_code == status.HTTP_200_OK + + +def test_get_environment_document_is_not_throttled_by_user_throttle( + server_side_sdk_client: APIClient, + environment: int, + environment_api_key: str, + mocker: MockerFixture, +): + # Given + mocker.patch("core.throttling.UserRateThrottle.get_rate", return_value="1/minute") + + url = reverse("api-v1:environment-document") + + # When + for _ in range(10): + response = server_side_sdk_client.get(url) + + # Then + assert response.status_code == status.HTTP_200_OK + + +def test_get_identities_is_not_throttled_by_user_throttle( + environment: int, + sdk_client: APIClient, + mocker: MockerFixture, + identity: int, + identity_identifier: str, +): + # Given + mocker.patch("core.throttling.UserRateThrottle.get_rate", return_value="1/minute") + + base_url = reverse("api-v1:sdk-identities") + url = f"{base_url}?identifier={identity_identifier}" + + # When + for _ in range(10): + response = sdk_client.get(url) + + # Then + assert response.status_code == status.HTTP_200_OK + + +def test_set_trait_for_an_identity_is_not_throttled_by_user_throttle( + environment: int, + server_side_sdk_client: APIClient, + identity: int, + identity_identifier: str, + mocker: MockerFixture, +): + # Given + mocker.patch("core.throttling.UserRateThrottle.get_rate", return_value="1/minute") + url = reverse("api-v1:sdk-traits-list") + data = { + "identity": {"identifier": identity_identifier}, + "trait_key": "key", + "trait_value": "value", + } + + # When + for _ in range(10): + res = server_side_sdk_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert res.status_code == status.HTTP_200_OK + + +def test_sdk_analytics_is_not_throttled_by_user_throttle( + mocker: MockerFixture, environment: int, sdk_client: APIClient +): + # Given + mocker.patch("core.throttling.UserRateThrottle.get_rate", return_value="1/minute") + + # When + for _ in range(10): + response = sdk_client.post("/api/v1/analytics/flags/") + + # Then + assert response.status_code == status.HTTP_200_OK + + +def test_self_hosted_telemetry_view_is_not_throttled_by_user_throttle( + mocker: MockerFixture, +): + # Given + api_client = APIClient() + mocker.patch("core.throttling.UserRateThrottle.get_rate", return_value="1/minute") + + data = { + "organisations": 1, + "projects": 1, + "environments": 1, + "features": 1, + "segments": 1, + "users": 1, + "debug_enabled": True, + "env": "test", + } + # When + for _ in range(10): + response = api_client.post("/api/v1/analytics/telemetry/", data=data) + + # Then + assert response.status_code == status.HTTP_201_CREATED diff --git a/api/tests/unit/core/test_redis_cluster.py b/api/tests/unit/core/test_redis_cluster.py new file mode 100644 index 000000000000..69c8447599e6 --- /dev/null +++ b/api/tests/unit/core/test_redis_cluster.py @@ -0,0 +1,87 @@ +import pytest +from core.redis_cluster import ClusterConnectionFactory +from django.core.exceptions import ImproperlyConfigured +from pytest_mock import MockerFixture + + +def test_cluster_connection_factory__connect_cache(mocker: MockerFixture): + # Given + mock_get_connection = mocker.patch.object( + ClusterConnectionFactory, "get_connection" + ) + connection_factory = ClusterConnectionFactory(options={}) + + url = "redis://localhost:6379" + make_connection_params = mocker.patch.object( + connection_factory, + "make_connection_params", + return_value={"url": url}, + ) + + # When + first_connection = connection_factory.connect(url) + + # Let's call it again + second_connection = connection_factory.connect(url) + + # Then + assert first_connection == mock_get_connection.return_value + assert second_connection == mock_get_connection.return_value + assert first_connection is second_connection + + # get_connection was only called once + mock_get_connection.assert_called_once_with({"url": url}) + + # make_connection_params was only called once + make_connection_params.assert_called_once_with(url) + + +def test_cluster_connection_factory__get_connection_with_non_conflicting_params( + mocker: MockerFixture, +): + # Given + mockRedisCluster = mocker.patch("core.redis_cluster.RedisCluster") + connection_factory = ClusterConnectionFactory( + options={"REDIS_CLIENT_KWARGS": {"decode_responses": False}} + ) + connection_params = {"host": "localhost", "port": 6379} + + # When + connection_factory.get_connection(connection_params) + + # Then + mockRedisCluster.assert_called_once_with( + decode_responses=False, host="localhost", port=6379 + ) + + +def test_cluster_connection_factory__get_connection_with_conflicting_params( + mocker: MockerFixture, +): + # Given + mockRedisCluster = mocker.patch("core.redis_cluster.RedisCluster") + connection_factory = ClusterConnectionFactory( + options={"REDIS_CLIENT_KWARGS": {"decode_responses": False}} + ) + connection_params = {"decode_responses": True} + + # When + with pytest.raises(ImproperlyConfigured): + connection_factory.get_connection(connection_params) + + # Then - ImproperlyConfigured exception is raised + mockRedisCluster.assert_not_called() + + +def test_disconnect(mocker: MockerFixture): + # Given + connection_factory = ClusterConnectionFactory({}) + mock_connection = mocker.MagicMock() + mock_disconnect_connection_pools = mocker.MagicMock() + mock_connection.disconnect_connection_pools = mock_disconnect_connection_pools + + # When + connection_factory.disconnect(mock_connection) + + # Then + mock_disconnect_connection_pools.assert_called_once() diff --git a/api/tests/unit/environments/identities/test_unit_identities_views.py b/api/tests/unit/environments/identities/test_unit_identities_views.py index 75e39c4c8706..0d83c09c5cbb 100644 --- a/api/tests/unit/environments/identities/test_unit_identities_views.py +++ b/api/tests/unit/environments/identities/test_unit_identities_views.py @@ -33,24 +33,6 @@ from util.tests import Helper -def test_get_identities_is_not_throttled_by_user_throttle( - environment, feature, identity, api_client, settings -): - # Given - settings.REST_FRAMEWORK = {"DEFAULT_THROTTLE_RATES": {"user": "1/minute"}} - - api_client.credentials(HTTP_X_ENVIRONMENT_KEY=environment.api_key) - base_url = reverse("api-v1:sdk-identities") - url = f"{base_url}?identifier={identity.identifier}" - - # When - for _ in range(10): - response = api_client.get(url) - - # Then - assert response.status_code == status.HTTP_200_OK - - @pytest.mark.django_db class IdentityTestCase(TestCase): identifier = "user1" diff --git a/api/tests/unit/environments/test_unit_environments_views_sdk_environment.py b/api/tests/unit/environments/test_unit_environments_views_sdk_environment.py index 5857fa66b895..2a69a3c846e0 100644 --- a/api/tests/unit/environments/test_unit_environments_views_sdk_environment.py +++ b/api/tests/unit/environments/test_unit_environments_views_sdk_environment.py @@ -82,22 +82,3 @@ def test_get_environment_document_fails_with_invalid_key( # We get a 403 since only the server side API keys are able to access the # environment document assert response.status_code == status.HTTP_403_FORBIDDEN - - -def test_get_environment_document_is_not_throttled_by_user_throttle( - environment, feature, settings, environment_api_key -): - # Given - settings.REST_FRAMEWORK = {"DEFAULT_THROTTLE_RATES": {"user": "1/minute"}} - - client = APIClient() - client.credentials(HTTP_X_ENVIRONMENT_KEY=environment_api_key.key) - - url = reverse("api-v1:environment-document") - - # When - for _ in range(10): - response = client.get(url) - - # Then - assert response.status_code == status.HTTP_200_OK diff --git a/infrastructure/aws/staging/ecs-task-definition-web.json b/infrastructure/aws/staging/ecs-task-definition-web.json index f2f9cb4d2f6b..6a6bbbec5b54 100644 --- a/infrastructure/aws/staging/ecs-task-definition-web.json +++ b/infrastructure/aws/staging/ecs-task-definition-web.json @@ -166,6 +166,22 @@ { "name": "USE_POSTGRES_FOR_ANALYTICS", "value": "True" + }, + { + "name": "DEFAULT_THROTTLE_CLASSES", + "value": "core.throttling.UserRateThrottle" + }, + { + "name": "DJANGO_REDIS_CONNECTION_FACTORY", + "value": "core.throttling.UserRateThrottle" + }, + { + "name": "USER_THROTTLE_CACHE_BACKEND", + "value": "django_redis.cache.RedisCache" + }, + { + "name": "USER_THROTTLE_CACHE_LOCATION", + "value": "rediss://serverless-redis-cache-7u3xil.serverless.euw2.cache.amazonaws.com:6379" } ], "secrets": [