Skip to content

Commit

Permalink
fix(redis_cache): extend DefaultClient class to add support for Redis…
Browse files Browse the repository at this point in the history
…ClusterException (#3392)
  • Loading branch information
gagantrivedi authored Feb 8, 2024
1 parent 14816a3 commit 0949963
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 23 deletions.
4 changes: 2 additions & 2 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,14 +638,13 @@
"USER_THROTTLE_CACHE_BACKEND", "django.core.cache.backends.locmem.LocMemCache"
)
USER_THROTTLE_CACHE_LOCATION = env.str("USER_THROTTLE_CACHE_LOCATION", "admin-throttle")
USER_THROTTLE_CACHE_OPTIONS = env.dict("USER_THROTTLE_CACHE_OPTIONS", default={})

# Using Redis for cache
# To use Redis for caching, set the cache backend to `django_redis.cache.RedisCache`.
# and set the cache location to the redis url
# ref: https://github.com/jazzband/django-redis/tree/5.4.0#configure-as-cache-backend

# Set this to `core.redis_cluster.ClusterConnectionFactory` when using Redis Cluster.
DJANGO_REDIS_CONNECTION_FACTORY = env.str("DJANGO_REDIS_CONNECTION_FACTORY", "")

# Avoid raising exceptions if redis is down
# ref: https://github.com/jazzband/django-redis/tree/5.4.0#memcached-exceptions-behavior
Expand Down Expand Up @@ -708,6 +707,7 @@
USER_THROTTLE_CACHE_NAME: {
"BACKEND": USER_THROTTLE_CACHE_BACKEND,
"LOCATION": USER_THROTTLE_CACHE_LOCATION,
"OPTIONS": USER_THROTTLE_CACHE_OPTIONS,
},
}

Expand Down
81 changes: 69 additions & 12 deletions api/core/redis_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,66 @@
```python
# settings.py
DJANGO_REDIS_CONNECTION_FACTORY = "core.redis_cluster.ClusterConnectionFactory"
"cache_name: {
"BACKEND": ...,
"LOCATION": ...,
"OPTIONS": {
"CLIENT_CLASS": "core.redis_cluster.SafeRedisClusterClient",
},
},
"""

import threading
from copy import deepcopy

from django.core.exceptions import ImproperlyConfigured
from django_redis.client.default import DefaultClient
from django_redis.exceptions import ConnectionInterrupted
from django_redis.pool import ConnectionFactory
from redis.backoff import DecorrelatedJitterBackoff
from redis.cluster import RedisCluster
from redis.exceptions import RedisClusterException
from redis.retry import Retry


class SafeRedisClusterClient(DefaultClient):
SAFE_METHODS = [
"set",
"get",
"incr_version",
"delete",
"delete_pattern",
"delete_many",
"clear",
"get_many",
"set_many",
"incr",
"has_key",
"keys",
]

@staticmethod
def _safe_operation(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except RedisClusterException as e:
raise ConnectionInterrupted(connection=None) from e

return wrapper

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Dynamically generate safe versions of methods
for method_name in self.SAFE_METHODS:
setattr(
self, method_name, self._safe_operation(getattr(super(), method_name))
)

# Let's use our own connection factory here
self.connection_factory = ClusterConnectionFactory(options=self._options)


class ClusterConnectionFactory(ConnectionFactory):
Expand Down Expand Up @@ -57,17 +108,23 @@ def get_connection(self, connection_params: dict) -> RedisCluster:
If we find conflicting client and connection kwargs, we'll raise an
error.
"""
client_cls_kwargs = deepcopy(self.redis_client_cls_kwargs)
# ... and smash 'em together (crashing if there's conflicts)...
for key, value in connection_params.items():
if key in client_cls_kwargs:
raise ImproperlyConfigured(
f"Found '{key}' in both the connection and the client kwargs"
)
client_cls_kwargs[key] = value

# ... and then build and return the client
return RedisCluster(**client_cls_kwargs)
try:
client_cls_kwargs = deepcopy(self.redis_client_cls_kwargs)
# ... and smash 'em together (crashing if there's conflicts)...
for key, value in connection_params.items():
if key in client_cls_kwargs:
raise ImproperlyConfigured(
f"Found '{key}' in both the connection and the client kwargs"
)
client_cls_kwargs[key] = value

# Add explicit retry
client_cls_kwargs["retry"] = Retry(DecorrelatedJitterBackoff(), 3)
# ... and then build and return the client
return RedisCluster(**client_cls_kwargs)
except Exception as e:
# Let django redis handle the exception
raise ConnectionInterrupted(connection=None) from e

def disconnect(self, connection: RedisCluster):
connection.disconnect_connection_pools()
68 changes: 64 additions & 4 deletions api/tests/unit/core/test_redis_cluster.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from core.redis_cluster import ClusterConnectionFactory
from django.core.exceptions import ImproperlyConfigured
from core.redis_cluster import ClusterConnectionFactory, SafeRedisClusterClient
from django_redis.exceptions import ConnectionInterrupted
from pytest_mock import MockerFixture
from redis.exceptions import RedisClusterException


def test_cluster_connection_factory__connect_cache(mocker: MockerFixture):
Expand Down Expand Up @@ -41,6 +42,8 @@ def test_cluster_connection_factory__get_connection_with_non_conflicting_params(
):
# Given
mockRedisCluster = mocker.patch("core.redis_cluster.RedisCluster")
mockedRetry = mocker.patch("core.redis_cluster.Retry")
mockedBackoff = mocker.patch("core.redis_cluster.DecorrelatedJitterBackoff")
connection_factory = ClusterConnectionFactory(
options={"REDIS_CLIENT_KWARGS": {"decode_responses": False}}
)
Expand All @@ -51,8 +54,12 @@ def test_cluster_connection_factory__get_connection_with_non_conflicting_params(

# Then
mockRedisCluster.assert_called_once_with(
decode_responses=False, host="localhost", port=6379
decode_responses=False,
host="localhost",
port=6379,
retry=mockedRetry.return_value,
)
mockedRetry.assert_called_once_with(mockedBackoff(), 3)


def test_cluster_connection_factory__get_connection_with_conflicting_params(
Expand All @@ -66,7 +73,7 @@ def test_cluster_connection_factory__get_connection_with_conflicting_params(
connection_params = {"decode_responses": True}

# When
with pytest.raises(ImproperlyConfigured):
with pytest.raises(ConnectionInterrupted):
connection_factory.get_connection(connection_params)

# Then - ImproperlyConfigured exception is raised
Expand All @@ -85,3 +92,56 @@ def test_disconnect(mocker: MockerFixture):

# Then
mock_disconnect_connection_pools.assert_called_once()


def test_safe_redis_cluster__safe_methods_raise_connection_interrupted(
mocker: MockerFixture, settings
):
# Given
# Internal client that will raise RedisClusterException on every call
mocked_redis_cluster_client = mocker.MagicMock(side_effect=RedisClusterException)

safe_redis_cluster_client = SafeRedisClusterClient("redis://test", {}, None)

# Replace the internal client with the mocked one
safe_redis_cluster_client.get_client = mocked_redis_cluster_client

# Mock the backend as well
safe_redis_cluster_client._backend = mocker.MagicMock()

# When / Then
with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.get("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.set("key", "value")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.incr_version("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.delete("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.delete_pattern("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.delete_many(["key"])

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.clear()

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.get_many(["key"])

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.set_many({"key": "value"})

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.incr("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.has_key("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.keys("key")
10 changes: 5 additions & 5 deletions infrastructure/aws/staging/ecs-task-definition-web.json
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,17 @@
"name": "DEFAULT_THROTTLE_CLASSES",
"value": "core.throttling.UserRateThrottle"
},
{
"name": "DJANGO_REDIS_CONNECTION_FACTORY",
"value": "core.redis_cluster.ClusterConnectionFactory"
},
{
"name": "USER_THROTTLE_CACHE_BACKEND",
"value": "django_redis.cache.RedisCache"
},
{
"name": "USER_THROTTLE_CACHE_LOCATION",
"value": "rediss://serverless-redis-cache-7u3xil.serverless.euw2.cache.amazonaws.com:6379"
},
{
"name": "USER_THROTTLE_CACHE_OPTIONS",
"value": "CLIENT_CLASS=core.redis_cluster.SafeRedisClusterClient"
}
],
"secrets": [
Expand Down Expand Up @@ -253,4 +253,4 @@
],
"cpu": "1024",
"memory": "2048"
}
}

0 comments on commit 0949963

Please sign in to comment.