Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rate-limit/redis): Use redis to store throttling data for admin endpoints #2863

Merged
merged 6 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,32 @@
CACHE_ENVIRONMENT_DOCUMENT_SECONDS = env.int("CACHE_ENVIRONMENT_DOCUMENT_SECONDS", 0)
ENVIRONMENT_DOCUMENT_CACHE_LOCATION = "environment-documents"

USER_THROTTLE_CACHE_NAME = "user-throttle"
USER_THROTTLE_CACHE_BACKEND = env.str(
"USER_THROTTLE_CACHE_BACKEND", "django.core.cache.backends.locmem.LocMemCache"
)
USER_THROTTLE_CACHE_LOCATION = env.str("USER_THROTTLE_CACHE_LOCATION", "admin-throttle")

# Using Redis for cache
# To use Redis for caching, set the cache backend to `django_redis.cache.RedisCache`.
# and set the cache location to the redis url
# ref: https://github.com/jazzband/django-redis/tree/5.4.0#configure-as-cache-backend

# Set this to `core.redis_cluster.ClusterConnectionFactory` when using Redis Cluster.
DJANGO_REDIS_CONNECTION_FACTORY = env.str("DJANGO_REDIS_CONNECTION_FACTORY", "")

# Avoid raising exceptions if redis is down
# ref: https://github.com/jazzband/django-redis/tree/5.4.0#memcached-exceptions-behavior
DJANGO_REDIS_IGNORE_EXCEPTIONS = env.bool(
"DJANGO_REDIS_IGNORE_EXCEPTIONS", default=True
)

# Log exceptions generated by django-redis
# ref:https://github.com/jazzband/django-redis/tree/5.4.0#log-ignored-exceptions
DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = env.bool(
"DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS", True
)

CACHES = {
"default": {
"BACKEND": "django.core.cache.backends.locmem.LocMemCache",
Expand Down Expand Up @@ -675,6 +701,10 @@
"LOCATION": ENVIRONMENT_SEGMENTS_CACHE_LOCATION,
"TIMEOUT": ENVIRONMENT_SEGMENTS_CACHE_SECONDS,
},
USER_THROTTLE_CACHE_NAME: {
"BACKEND": USER_THROTTLE_CACHE_BACKEND,
"LOCATION": USER_THROTTLE_CACHE_LOCATION,
},
}

TRENCH_AUTH = {
Expand Down
1 change: 1 addition & 0 deletions api/app/settings/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# We dont want to track tests
ENABLE_TELEMETRY = False
MAX_PROJECTS_IN_FREE_PLAN = 10
REST_FRAMEWORK["DEFAULT_THROTTLE_CLASSES"] = ["core.throttling.UserRateThrottle"]
REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"] = {
"login": "100/min",
"mfa_code": "5/min",
Expand Down
2 changes: 2 additions & 0 deletions api/app_analytics/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class SDKAnalyticsFlags(GenericAPIView):

permission_classes = (EnvironmentKeyPermissions,)
authentication_classes = (EnvironmentKeyAuthentication,)
throttle_classes = []

def get_serializer_class(self):
if getattr(self, "swagger_fake_view", False):
Expand Down Expand Up @@ -116,6 +117,7 @@ class SelfHostedTelemetryAPIView(CreateAPIView):

permission_classes = ()
authentication_classes = ()
throttle_classes = []
serializer_class = TelemetrySerializer


Expand Down
12 changes: 9 additions & 3 deletions api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import boto3
import pytest
from django.contrib.contenttypes.models import ContentType
from django.core.cache import cache
from django.core.cache import caches
from flag_engine.segments.constants import EQUAL
from moto import mock_dynamodb
from mypy_boto3_dynamodb.service_resource import DynamoDBServiceResource, Table
Expand Down Expand Up @@ -350,9 +350,15 @@ def reset_cache():
# https://groups.google.com/g/django-developers/c/zlaPsP13dUY
# TL;DR: Use this if your test interacts with cache since django
# does not clear cache after every test
cache.clear()
# Clear all caches before the test
for cache in caches.all():
cache.clear()

yield
cache.clear()

# Clear all caches after the test
for cache in caches.all():
cache.clear()


@pytest.fixture()
Expand Down
73 changes: 73 additions & 0 deletions api/core/redis_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Temporary module that adds support for Redis Cluster to django-redis by implementing
a connection factory class(`ClusterConnectionFactory`).
This module should be removed once [this](https://github.com/jazzband/django-redis/issues/606)
is resolved.

Usage:
------
Include the following configuration in Django project's settings.py file:

```python
# settings.py

DJANGO_REDIS_CONNECTION_FACTORY = "core.redis_cluster.ClusterConnectionFactory"
"""

import threading
from copy import deepcopy

from django.core.exceptions import ImproperlyConfigured
from django_redis.pool import ConnectionFactory
from redis.cluster import RedisCluster


class ClusterConnectionFactory(ConnectionFactory):
"""A connection factory for redis.cluster.RedisCluster
The cluster client manages connection pools internally, so we don't want to
do it at this level like the base ConnectionFactory does.
"""

# A global cache of URL->client so that within a process, we will reuse a
# single client, and therefore a single set of connection pools.
_clients = {}
_clients_lock = threading.Lock()

def connect(self, url: str) -> RedisCluster:
"""Given a connection url, return a client instance.
Prefer to return from our cache but if we don't yet have one build it
to populate the cache.
"""
if url not in self._clients:
with self._clients_lock:
if url not in self._clients:
params = self.make_connection_params(url)
self._clients[url] = self.get_connection(params)

return self._clients[url]

def get_connection(self, connection_params: dict) -> RedisCluster:
"""
Given connection_params, return a new client instance.
Basic django-redis ConnectionFactory manages a cache of connection
pools and builds a fresh client each time. because the cluster client
manages its own connection pools, we will instead merge the
"connection" and "client" kwargs and throw them all at the client to
sort out.
If we find conflicting client and connection kwargs, we'll raise an
error.
"""
client_cls_kwargs = deepcopy(self.redis_client_cls_kwargs)
# ... and smash 'em together (crashing if there's conflicts)...
for key, value in connection_params.items():
if key in client_cls_kwargs:
raise ImproperlyConfigured(
f"Found '{key}' in both the connection and the client kwargs"
)
client_cls_kwargs[key] = value

# ... and then build and return the client
return RedisCluster(**client_cls_kwargs)

def disconnect(self, connection: RedisCluster):
connection.disconnect_connection_pools()
7 changes: 7 additions & 0 deletions api/core/throttling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from django.conf import settings
from django.core.cache import caches
from rest_framework import throttling


class UserRateThrottle(throttling.UserRateThrottle):
cache = caches[settings.USER_THROTTLE_CACHE_NAME]
1 change: 1 addition & 0 deletions api/environments/identities/traits/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class SDKTraitsDeprecated(SDKAPIView):
# API to handle /api/v1/identities/<identifier>/traits/<trait_key> endpoints
# if Identity or Trait does not exist it will create one, otherwise will fetch existing
serializer_class = TraitSerializerBasic
throttle_classes = []

schema = None

Expand Down
1 change: 1 addition & 0 deletions api/environments/identities/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class SDKIdentitiesDeprecated(SDKAPIView):
# if Identity does not exist it will create one, otherwise will fetch existing

serializer_class = IdentifyWithTraitsSerializer
throttle_classes = []

schema = None

Expand Down
2 changes: 1 addition & 1 deletion api/features/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,8 +602,8 @@ class SDKFeatureStates(GenericAPIView):
permission_classes = (EnvironmentKeyPermissions,)
authentication_classes = (EnvironmentKeyAuthentication,)
renderer_classes = [JSONRenderer]
throttle_classes = []
pagination_class = None
throttle_classes = []

@swagger_auto_schema(
query_serializer=SDKFeatureStatesQuerySerializer(),
Expand Down
4 changes: 2 additions & 2 deletions api/integrations/sentry/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from django.conf import settings

NON_FUNCTIONAL_ENDPOINTS = ("/health", "")
SDK_ENDPOINTS = (
SDK_ENDPOINTS = {
"/api/v1/flags",
"/api/v1/identities",
"/api/v1/traits",
"/api/v1/traits/bulk",
"/api/v1/environment-document",
"/api/v1/analytics/flags",
)
}


def traces_sampler(ctx):
Expand Down
50 changes: 48 additions & 2 deletions api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ pydantic = "~1.10.9"
pyngo = "~1.6.0"
flagsmith = "^3.4.0"
python-gnupg = "^0.5.1"
django-redis = "^5.4.0"

[tool.poetry.group.auth-controller]
optional = true
Expand Down
12 changes: 12 additions & 0 deletions api/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ def sdk_client(environment_api_key):
return client


@pytest.fixture()
def server_side_sdk_client(
admin_client: APIClient, environment: int, environment_api_key: str
) -> APIClient:
url = reverse("api-v1:environments:api-keys-list", args={environment_api_key})
response = admin_client.post(url, data={"name": "Some key"})

client = APIClient()
client.credentials(HTTP_X_ENVIRONMENT_KEY=response.json()["key"])
return client


@pytest.fixture()
def default_feature_value():
return "default_value"
Expand Down
Loading
Loading