From 0949963a804e4d9aa69120d48181c220f9bdd375 Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Thu, 8 Feb 2024 14:15:38 +0530 Subject: [PATCH] fix(redis_cache): extend DefaultClient class to add support for RedisClusterException (#3392) --- api/app/settings/common.py | 4 +- api/core/redis_cluster.py | 81 ++++++++++++++++--- api/tests/unit/core/test_redis_cluster.py | 68 +++++++++++++++- .../aws/staging/ecs-task-definition-web.json | 10 +-- 4 files changed, 140 insertions(+), 23 deletions(-) diff --git a/api/app/settings/common.py b/api/app/settings/common.py index e7502412cfd3..ec551025a4c0 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -638,14 +638,13 @@ "USER_THROTTLE_CACHE_BACKEND", "django.core.cache.backends.locmem.LocMemCache" ) USER_THROTTLE_CACHE_LOCATION = env.str("USER_THROTTLE_CACHE_LOCATION", "admin-throttle") +USER_THROTTLE_CACHE_OPTIONS = env.dict("USER_THROTTLE_CACHE_OPTIONS", default={}) # Using Redis for cache # To use Redis for caching, set the cache backend to `django_redis.cache.RedisCache`. # and set the cache location to the redis url # ref: https://github.com/jazzband/django-redis/tree/5.4.0#configure-as-cache-backend -# Set this to `core.redis_cluster.ClusterConnectionFactory` when using Redis Cluster. -DJANGO_REDIS_CONNECTION_FACTORY = env.str("DJANGO_REDIS_CONNECTION_FACTORY", "") # Avoid raising exceptions if redis is down # ref: https://github.com/jazzband/django-redis/tree/5.4.0#memcached-exceptions-behavior @@ -708,6 +707,7 @@ USER_THROTTLE_CACHE_NAME: { "BACKEND": USER_THROTTLE_CACHE_BACKEND, "LOCATION": USER_THROTTLE_CACHE_LOCATION, + "OPTIONS": USER_THROTTLE_CACHE_OPTIONS, }, } diff --git a/api/core/redis_cluster.py b/api/core/redis_cluster.py index 985760b7d078..5f237aa25f8f 100644 --- a/api/core/redis_cluster.py +++ b/api/core/redis_cluster.py @@ -11,15 +11,66 @@ ```python # settings.py -DJANGO_REDIS_CONNECTION_FACTORY = "core.redis_cluster.ClusterConnectionFactory" +"cache_name: { + "BACKEND": ..., + "LOCATION": ..., + "OPTIONS": { + "CLIENT_CLASS": "core.redis_cluster.SafeRedisClusterClient", + + }, + }, """ import threading from copy import deepcopy from django.core.exceptions import ImproperlyConfigured +from django_redis.client.default import DefaultClient +from django_redis.exceptions import ConnectionInterrupted from django_redis.pool import ConnectionFactory +from redis.backoff import DecorrelatedJitterBackoff from redis.cluster import RedisCluster +from redis.exceptions import RedisClusterException +from redis.retry import Retry + + +class SafeRedisClusterClient(DefaultClient): + SAFE_METHODS = [ + "set", + "get", + "incr_version", + "delete", + "delete_pattern", + "delete_many", + "clear", + "get_many", + "set_many", + "incr", + "has_key", + "keys", + ] + + @staticmethod + def _safe_operation(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except RedisClusterException as e: + raise ConnectionInterrupted(connection=None) from e + + return wrapper + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Dynamically generate safe versions of methods + for method_name in self.SAFE_METHODS: + setattr( + self, method_name, self._safe_operation(getattr(super(), method_name)) + ) + + # Let's use our own connection factory here + self.connection_factory = ClusterConnectionFactory(options=self._options) class ClusterConnectionFactory(ConnectionFactory): @@ -57,17 +108,23 @@ def get_connection(self, connection_params: dict) -> RedisCluster: If we find conflicting client and connection kwargs, we'll raise an error. """ - client_cls_kwargs = deepcopy(self.redis_client_cls_kwargs) - # ... and smash 'em together (crashing if there's conflicts)... - for key, value in connection_params.items(): - if key in client_cls_kwargs: - raise ImproperlyConfigured( - f"Found '{key}' in both the connection and the client kwargs" - ) - client_cls_kwargs[key] = value - - # ... and then build and return the client - return RedisCluster(**client_cls_kwargs) + try: + client_cls_kwargs = deepcopy(self.redis_client_cls_kwargs) + # ... and smash 'em together (crashing if there's conflicts)... + for key, value in connection_params.items(): + if key in client_cls_kwargs: + raise ImproperlyConfigured( + f"Found '{key}' in both the connection and the client kwargs" + ) + client_cls_kwargs[key] = value + + # Add explicit retry + client_cls_kwargs["retry"] = Retry(DecorrelatedJitterBackoff(), 3) + # ... and then build and return the client + return RedisCluster(**client_cls_kwargs) + except Exception as e: + # Let django redis handle the exception + raise ConnectionInterrupted(connection=None) from e def disconnect(self, connection: RedisCluster): connection.disconnect_connection_pools() diff --git a/api/tests/unit/core/test_redis_cluster.py b/api/tests/unit/core/test_redis_cluster.py index 69c8447599e6..6aaef4b89779 100644 --- a/api/tests/unit/core/test_redis_cluster.py +++ b/api/tests/unit/core/test_redis_cluster.py @@ -1,7 +1,8 @@ import pytest -from core.redis_cluster import ClusterConnectionFactory -from django.core.exceptions import ImproperlyConfigured +from core.redis_cluster import ClusterConnectionFactory, SafeRedisClusterClient +from django_redis.exceptions import ConnectionInterrupted from pytest_mock import MockerFixture +from redis.exceptions import RedisClusterException def test_cluster_connection_factory__connect_cache(mocker: MockerFixture): @@ -41,6 +42,8 @@ def test_cluster_connection_factory__get_connection_with_non_conflicting_params( ): # Given mockRedisCluster = mocker.patch("core.redis_cluster.RedisCluster") + mockedRetry = mocker.patch("core.redis_cluster.Retry") + mockedBackoff = mocker.patch("core.redis_cluster.DecorrelatedJitterBackoff") connection_factory = ClusterConnectionFactory( options={"REDIS_CLIENT_KWARGS": {"decode_responses": False}} ) @@ -51,8 +54,12 @@ def test_cluster_connection_factory__get_connection_with_non_conflicting_params( # Then mockRedisCluster.assert_called_once_with( - decode_responses=False, host="localhost", port=6379 + decode_responses=False, + host="localhost", + port=6379, + retry=mockedRetry.return_value, ) + mockedRetry.assert_called_once_with(mockedBackoff(), 3) def test_cluster_connection_factory__get_connection_with_conflicting_params( @@ -66,7 +73,7 @@ def test_cluster_connection_factory__get_connection_with_conflicting_params( connection_params = {"decode_responses": True} # When - with pytest.raises(ImproperlyConfigured): + with pytest.raises(ConnectionInterrupted): connection_factory.get_connection(connection_params) # Then - ImproperlyConfigured exception is raised @@ -85,3 +92,56 @@ def test_disconnect(mocker: MockerFixture): # Then mock_disconnect_connection_pools.assert_called_once() + + +def test_safe_redis_cluster__safe_methods_raise_connection_interrupted( + mocker: MockerFixture, settings +): + # Given + # Internal client that will raise RedisClusterException on every call + mocked_redis_cluster_client = mocker.MagicMock(side_effect=RedisClusterException) + + safe_redis_cluster_client = SafeRedisClusterClient("redis://test", {}, None) + + # Replace the internal client with the mocked one + safe_redis_cluster_client.get_client = mocked_redis_cluster_client + + # Mock the backend as well + safe_redis_cluster_client._backend = mocker.MagicMock() + + # When / Then + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.get("key") + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.set("key", "value") + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.incr_version("key") + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.delete("key") + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.delete_pattern("key") + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.delete_many(["key"]) + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.clear() + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.get_many(["key"]) + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.set_many({"key": "value"}) + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.incr("key") + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.has_key("key") + + with pytest.raises(ConnectionInterrupted): + safe_redis_cluster_client.keys("key") diff --git a/infrastructure/aws/staging/ecs-task-definition-web.json b/infrastructure/aws/staging/ecs-task-definition-web.json index 5459af0e0ce3..6d87dccca407 100644 --- a/infrastructure/aws/staging/ecs-task-definition-web.json +++ b/infrastructure/aws/staging/ecs-task-definition-web.json @@ -171,10 +171,6 @@ "name": "DEFAULT_THROTTLE_CLASSES", "value": "core.throttling.UserRateThrottle" }, - { - "name": "DJANGO_REDIS_CONNECTION_FACTORY", - "value": "core.redis_cluster.ClusterConnectionFactory" - }, { "name": "USER_THROTTLE_CACHE_BACKEND", "value": "django_redis.cache.RedisCache" @@ -182,6 +178,10 @@ { "name": "USER_THROTTLE_CACHE_LOCATION", "value": "rediss://serverless-redis-cache-7u3xil.serverless.euw2.cache.amazonaws.com:6379" + }, + { + "name": "USER_THROTTLE_CACHE_OPTIONS", + "value": "CLIENT_CLASS=core.redis_cluster.SafeRedisClusterClient" } ], "secrets": [ @@ -253,4 +253,4 @@ ], "cpu": "1024", "memory": "2048" -} \ No newline at end of file +}