Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(redis_cache): extend DefaultClient class to add support for RedisClusterException #3392

Merged
merged 6 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@
"USER_THROTTLE_CACHE_BACKEND", "django.core.cache.backends.locmem.LocMemCache"
)
USER_THROTTLE_CACHE_LOCATION = env.str("USER_THROTTLE_CACHE_LOCATION", "admin-throttle")
USER_THROTTLE_CACHE_OPTIONS = env.dict("USER_THROTTLE_CACHE_OPTIONS", default={})

# Using Redis for cache
# To use Redis for caching, set the cache backend to `django_redis.cache.RedisCache`.
Expand Down Expand Up @@ -705,6 +706,7 @@
USER_THROTTLE_CACHE_NAME: {
"BACKEND": USER_THROTTLE_CACHE_BACKEND,
"LOCATION": USER_THROTTLE_CACHE_LOCATION,
"OPTIONS": USER_THROTTLE_CACHE_OPTIONS,
},
}

Expand Down
77 changes: 66 additions & 11 deletions api/core/redis_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,63 @@
# settings.py

DJANGO_REDIS_CONNECTION_FACTORY = "core.redis_cluster.ClusterConnectionFactory"

"cache_name: {
"BACKEND": ...,
"LOCATION": ...,
"OPTIONS": {
"CLIENT_CLASS": "core.redis_cluster.SafeRedisClusterClient",

},
},
"""

import threading
from copy import deepcopy

from django.core.exceptions import ImproperlyConfigured
from django_redis.client.default import DefaultClient
from django_redis.exceptions import ConnectionInterrupted
from django_redis.pool import ConnectionFactory
from redis.backoff import DecorrelatedJitterBackoff
from redis.cluster import RedisCluster
from redis.exceptions import RedisClusterException
from redis.retry import Retry


class SafeRedisClusterClient(DefaultClient):
SAFE_METHODS = [
"set",
"get",
"incr_version",
"delete",
"delete_pattern",
"delete_many",
"clear",
"get_many",
"set_many",
"incr",
"has_key",
"keys",
]

def _safe_operation(self, func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except RedisClusterException as e:
raise ConnectionInterrupted(connection=None) from e

return wrapper
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: probably should be a staticmethod.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, done


def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Dynamically generate safe versions of methods
for method_name in self.SAFE_METHODS:
setattr(
self, method_name, self._safe_operation(getattr(super(), method_name))
)


class ClusterConnectionFactory(ConnectionFactory):
Expand Down Expand Up @@ -57,17 +106,23 @@ def get_connection(self, connection_params: dict) -> RedisCluster:
If we find conflicting client and connection kwargs, we'll raise an
error.
"""
client_cls_kwargs = deepcopy(self.redis_client_cls_kwargs)
# ... and smash 'em together (crashing if there's conflicts)...
for key, value in connection_params.items():
if key in client_cls_kwargs:
raise ImproperlyConfigured(
f"Found '{key}' in both the connection and the client kwargs"
)
client_cls_kwargs[key] = value

# ... and then build and return the client
return RedisCluster(**client_cls_kwargs)
try:
client_cls_kwargs = deepcopy(self.redis_client_cls_kwargs)
# ... and smash 'em together (crashing if there's conflicts)...
for key, value in connection_params.items():
if key in client_cls_kwargs:
raise ImproperlyConfigured(
f"Found '{key}' in both the connection and the client kwargs"
)
client_cls_kwargs[key] = value

# Add explicit retry
client_cls_kwargs["retry"] = Retry(DecorrelatedJitterBackoff(), 3)
# ... and then build and return the client
return RedisCluster(**client_cls_kwargs)
except Exception as e:
# Let django redis handle the exception
raise ConnectionInterrupted(connection=None) from e

def disconnect(self, connection: RedisCluster):
connection.disconnect_connection_pools()
72 changes: 68 additions & 4 deletions api/tests/unit/core/test_redis_cluster.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from core.redis_cluster import ClusterConnectionFactory
from django.core.exceptions import ImproperlyConfigured
from core.redis_cluster import ClusterConnectionFactory, SafeRedisClusterClient
from django_redis.exceptions import ConnectionInterrupted
from pytest_mock import MockerFixture
from redis.exceptions import RedisClusterException


def test_cluster_connection_factory__connect_cache(mocker: MockerFixture):
Expand Down Expand Up @@ -41,6 +42,8 @@ def test_cluster_connection_factory__get_connection_with_non_conflicting_params(
):
# Given
mockRedisCluster = mocker.patch("core.redis_cluster.RedisCluster")
mockedRetry = mocker.patch("core.redis_cluster.Retry")
mockedBackoff = mocker.patch("core.redis_cluster.DecorrelatedJitterBackoff")
connection_factory = ClusterConnectionFactory(
options={"REDIS_CLIENT_KWARGS": {"decode_responses": False}}
)
Expand All @@ -51,8 +54,12 @@ def test_cluster_connection_factory__get_connection_with_non_conflicting_params(

# Then
mockRedisCluster.assert_called_once_with(
decode_responses=False, host="localhost", port=6379
decode_responses=False,
host="localhost",
port=6379,
retry=mockedRetry.return_value,
)
mockedRetry.assert_called_once_with(mockedBackoff(), 3)


def test_cluster_connection_factory__get_connection_with_conflicting_params(
Expand All @@ -66,7 +73,7 @@ def test_cluster_connection_factory__get_connection_with_conflicting_params(
connection_params = {"decode_responses": True}

# When
with pytest.raises(ImproperlyConfigured):
with pytest.raises(ConnectionInterrupted):
connection_factory.get_connection(connection_params)

# Then - ImproperlyConfigured exception is raised
Expand All @@ -85,3 +92,60 @@ def test_disconnect(mocker: MockerFixture):

# Then
mock_disconnect_connection_pools.assert_called_once()


def test_safe_redis_cluster__safe_methods_raise_connection_interrupted(
mocker: MockerFixture, settings
):
# Given
settings.DJANGO_REDIS_CONNECTION_FACTORY = (
"core.redis_cluster.ClusterConnectionFactory"
)

# Internal client that will raise RedisClusterException on every call
mocked_redis_cluster_client = mocker.MagicMock(side_effect=RedisClusterException)

safe_redis_cluster_client = SafeRedisClusterClient("redis://test", {}, None)

# Replace the internal client with the mocked one
safe_redis_cluster_client.get_client = mocked_redis_cluster_client

# Mock the backend as well
safe_redis_cluster_client._backend = mocker.MagicMock()

# When / Then
with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.get("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.set("key", "value")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.incr_version("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.delete("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.delete_pattern("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.delete_many(["key"])

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.clear()

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.get_many(["key"])

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.set_many({"key": "value"})

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.incr("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.has_key("key")

with pytest.raises(ConnectionInterrupted):
safe_redis_cluster_client.keys("key")
6 changes: 5 additions & 1 deletion infrastructure/aws/staging/ecs-task-definition-web.json
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@
{
"name": "USER_THROTTLE_CACHE_LOCATION",
"value": "rediss://serverless-redis-cache-7u3xil.serverless.euw2.cache.amazonaws.com:6379"
},
{
"name": "USER_THROTTLE_CACHE_OPTIONS",
"value": "CLIENT_CLASS=core.redis_cluster.SafeRedisClusterClient"
}
],
"secrets": [
Expand Down Expand Up @@ -253,4 +257,4 @@
],
"cpu": "1024",
"memory": "2048"
}
}
Loading