Skip to content

Commit

Permalink
feat: Add support for replicas and cross region replicas (#3300)
Browse files Browse the repository at this point in the history
  • Loading branch information
zachaysan authored Feb 9, 2024
1 parent 42634ec commit bda59f5
Show file tree
Hide file tree
Showing 6 changed files with 349 additions and 19 deletions.
2 changes: 2 additions & 0 deletions api/app/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class ImproperlyConfiguredError(RuntimeError):
pass
93 changes: 91 additions & 2 deletions api/app/routers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,79 @@
import logging
import random
from enum import Enum

from django.conf import settings
from django.core.cache import cache
from django.db import connections

from .exceptions import ImproperlyConfiguredError

logger = logging.getLogger(__name__)

CONNECTION_CHECK_CACHE_TTL = 2


class ReplicaReadStrategy(Enum):
DISTRIBUTED = "DISTRIBUTED"
SEQUENTIAL = "SEQUENTIAL"


def connection_check(database: str) -> bool:
try:
conn = connections.create_connection(database)
conn.connect()
usable = conn.is_usable()
if not usable:
logger.warning(
f"Unable to access database {database} during connection check"
)
except Exception:
usable = False
logger.error(
"Encountered exception during connection",
exc_info=True,
)

if usable:
cache.set(
f"db_connection_active.{database}", "online", CONNECTION_CHECK_CACHE_TTL
)
else:
cache.set(
f"db_connection_active.{database}", "offline", CONNECTION_CHECK_CACHE_TTL
)

return usable


class PrimaryReplicaRouter:
def db_for_read(self, model, **hints):
if settings.NUM_DB_REPLICAS == 0:
return "default"
return random.choice(
[f"replica_{i}" for i in range(1, settings.NUM_DB_REPLICAS + 1)]

replicas = [f"replica_{i}" for i in range(1, settings.NUM_DB_REPLICAS + 1)]
replica = self._get_replica(replicas)
if replica:
# This return is the most likely as replicas should be
# online and properly functioning.
return replica

# Since no replicas are available, fall back to the cross
# region replicas which have worse availability.
cross_region_replicas = [
f"cross_region_replica_{i}"
for i in range(1, settings.NUM_CROSS_REGION_DB_REPLICAS + 1)
]

cross_region_replica = self._get_replica(cross_region_replicas)
if cross_region_replica:
return cross_region_replica

# No available replicas, so fallback to the default.
logger.warning(
"Unable to serve any available replicas, falling back to default database"
)
return "default"

def db_for_write(self, model, **hints):
return "default"
Expand All @@ -22,6 +86,10 @@ def allow_relation(self, obj1, obj2, **hints):
db_set = {
"default",
*[f"replica_{i}" for i in range(1, settings.NUM_DB_REPLICAS + 1)],
*[
f"cross_region_replica_{i}"
for i in range(1, settings.NUM_CROSS_REGION_DB_REPLICAS + 1)
],
}
if obj1._state.db in db_set and obj2._state.db in db_set:
return True
Expand All @@ -30,6 +98,27 @@ def allow_relation(self, obj1, obj2, **hints):
def allow_migrate(self, db, app_label, model_name=None, **hints):
return db == "default"

def _get_replica(self, replicas: list[str]) -> None | str:
while replicas:
if settings.REPLICA_READ_STRATEGY == ReplicaReadStrategy.DISTRIBUTED.value:
database = random.choice(replicas)
elif settings.REPLICA_READ_STRATEGY == ReplicaReadStrategy.SEQUENTIAL.value:
database = replicas[0]
else:
raise ImproperlyConfiguredError(
f"Unknown REPLICA_READ_STRATEGY {settings.REPLICA_READ_STRATEGY}"
)

replicas.remove(database)
db_cache = cache.get(f"db_connection_active.{database}")
if db_cache == "online":
return database
if db_cache == "offline":
continue

if connection_check(database):
return database


class AnalyticsRouter:
route_app_labels = ["app_analytics"]
Expand Down
28 changes: 28 additions & 0 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from django.core.management.utils import get_random_secret_key
from environs import Env

from app.routers import ReplicaReadStrategy
from task_processor.task_run_method import TaskRunMethod

env = Env()
Expand Down Expand Up @@ -166,6 +167,7 @@

DATABASE_ROUTERS = ["app.routers.PrimaryReplicaRouter"]
NUM_DB_REPLICAS = 0
NUM_CROSS_REGION_DB_REPLICAS = 0
# Allows collectstatic to run without a database, mainly for Docker builds to collectstatic at build time
if "DATABASE_URL" in os.environ:
DATABASES = {
Expand All @@ -178,11 +180,37 @@
"REPLICA_DATABASE_URLS", default=[], delimiter=REPLICA_DATABASE_URLS_DELIMITER
)
NUM_DB_REPLICAS = len(REPLICA_DATABASE_URLS)

# Cross region replica databases are used as fallbacks if the
# primary replica set becomes unavailable.
CROSS_REGION_REPLICA_DATABASE_URLS_DELIMITER = env(
"CROSS_REGION_REPLICA_DATABASE_URLS_DELIMITER", ","
)
CROSS_REGION_REPLICA_DATABASE_URLS = env.list(
"CROSS_REGION_REPLICA_DATABASE_URLS",
default=[],
delimiter=CROSS_REGION_REPLICA_DATABASE_URLS_DELIMITER,
)
NUM_CROSS_REGION_DB_REPLICAS = len(CROSS_REGION_REPLICA_DATABASE_URLS)

# DISTRIBUTED spreads the load out across replicas while
# SEQUENTIAL only falls back once the first replica connection is faulty
REPLICA_READ_STRATEGY = env.enum(
"REPLICA_READ_STRATEGY",
type=ReplicaReadStrategy,
default=ReplicaReadStrategy.DISTRIBUTED.value,
)

for i, db_url in enumerate(REPLICA_DATABASE_URLS, start=1):
DATABASES[f"replica_{i}"] = dj_database_url.parse(
db_url, conn_max_age=DJANGO_DB_CONN_MAX_AGE
)

for i, db_url in enumerate(CROSS_REGION_REPLICA_DATABASE_URLS, start=1):
DATABASES[f"cross_region_replica_{i}"] = dj_database_url.parse(
db_url, conn_max_age=DJANGO_DB_CONN_MAX_AGE
)

if "ANALYTICS_DATABASE_URL" in os.environ:
DATABASES["analytics"] = dj_database_url.parse(
env("ANALYTICS_DATABASE_URL"), conn_max_age=DJANGO_DB_CONN_MAX_AGE
Expand Down
196 changes: 196 additions & 0 deletions api/tests/unit/app/test_unit_app_routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
from pytest_django.fixtures import SettingsWrapper
from pytest_mock import MockerFixture

from app.routers import (
PrimaryReplicaRouter,
ReplicaReadStrategy,
connection_check,
)
from users.models import FFAdminUser


def test_connection_check_to_default_database(db: None, reset_cache: None) -> None:
# When
connection_check_works = connection_check("default")

# Then
assert connection_check_works is True


def test_replica_router_db_for_read_with_one_offline_replica(
db: None,
settings: SettingsWrapper,
mocker: MockerFixture,
reset_cache: None,
) -> None:
# Given
settings.NUM_DB_REPLICAS = 4

# Set unused cross regional db for testing non-inclusion.
settings.NUM_CROSS_REGION_DB_REPLICAS = 2
settings.REPLICA_READ_STRATEGY = ReplicaReadStrategy.DISTRIBUTED.value

conn_patch = mocker.MagicMock()
conn_patch.is_usable.side_effect = (False, True)
create_connection_patch = mocker.patch(
"app.routers.connections.create_connection", return_value=conn_patch
)

router = PrimaryReplicaRouter()

# When
result = router.db_for_read(FFAdminUser)

# Then
# Read strategy DISTRIBUTED is random, so just this is a check
# against loading the primary or one of the cross region replicas
assert result.startswith("replica_")

# Check that the number of replica call counts is as expected.
conn_call_count = 2
assert create_connection_patch.call_count == conn_call_count
assert conn_patch.is_usable.call_count == conn_call_count


def test_replica_router_db_for_read_with_local_offline_replicas(
db: None,
settings: SettingsWrapper,
mocker: MockerFixture,
reset_cache: None,
) -> None:
# Given
settings.NUM_DB_REPLICAS = 4

# Use cross regional db for fallback after replicas.
settings.NUM_CROSS_REGION_DB_REPLICAS = 2
settings.REPLICA_READ_STRATEGY = ReplicaReadStrategy.DISTRIBUTED.value

conn_patch = mocker.MagicMock()

# All four replicas go offline and so does one of the cross
# regional replica as well, before finally the last cross
# region replica is finally connected to.
conn_patch.is_usable.side_effect = (
False,
False,
False,
False,
False,
True,
)
create_connection_patch = mocker.patch(
"app.routers.connections.create_connection", return_value=conn_patch
)

router = PrimaryReplicaRouter()

# When
result = router.db_for_read(FFAdminUser)

# Then
# Read strategy DISTRIBUTED is random, so just this is a check
# against loading the primary or one of the cross region replicas
assert result.startswith("cross_region_replica_")

# Check that the number of replica call counts is as expected.
conn_call_count = 6
assert create_connection_patch.call_count == conn_call_count
assert conn_patch.is_usable.call_count == conn_call_count


def test_replica_router_db_for_read_with_all_offline_replicas(
db: None,
settings: SettingsWrapper,
mocker: MockerFixture,
reset_cache: None,
) -> None:
# Given
settings.NUM_DB_REPLICAS = 4
settings.NUM_CROSS_REGION_DB_REPLICAS = 2
settings.REPLICA_READ_STRATEGY = ReplicaReadStrategy.DISTRIBUTED.value

conn_patch = mocker.MagicMock()

# All replicas go offline.
conn_patch.is_usable.return_value = False
create_connection_patch = mocker.patch(
"app.routers.connections.create_connection", return_value=conn_patch
)

router = PrimaryReplicaRouter()

# When
result = router.db_for_read(FFAdminUser)

# Then
# Fallback to primary database if all replicas are offline.
assert result == "default"

# Check that the number of replica call counts is as expected.
conn_call_count = 6
assert create_connection_patch.call_count == conn_call_count
assert conn_patch.is_usable.call_count == conn_call_count


def test_replica_router_db_with_sequential_read(
db: None,
settings: SettingsWrapper,
mocker: MockerFixture,
reset_cache: None,
) -> None:
# Given
settings.NUM_DB_REPLICAS = 100
settings.NUM_CROSS_REGION_DB_REPLICAS = 2
settings.REPLICA_READ_STRATEGY = ReplicaReadStrategy.SEQUENTIAL.value

conn_patch = mocker.MagicMock()

# First replica is offline, so must fall back to second one.
conn_patch.is_usable.side_effect = (False, True)
create_connection_patch = mocker.patch(
"app.routers.connections.create_connection", return_value=conn_patch
)

router = PrimaryReplicaRouter()

# When
result = router.db_for_read(FFAdminUser)

# Then
# Fallback from first replica to second one.
assert result == "replica_2"

# Check that the number of replica call counts is as expected.
conn_call_count = 2
assert create_connection_patch.call_count == conn_call_count
assert conn_patch.is_usable.call_count == conn_call_count


def test_replica_router_db_no_replicas(
db: None,
settings: SettingsWrapper,
mocker: MockerFixture,
reset_cache: None,
) -> None:
# Given
settings.NUM_DB_REPLICAS = 0
settings.NUM_CROSS_REGION_DB_REPLICAS = 0

conn_patch = mocker.MagicMock()

# All replicas should be ignored.
create_connection_patch = mocker.patch(
"app.routers.connections.create_connection", return_value=conn_patch
)

router = PrimaryReplicaRouter()

# When
result = router.db_for_read(FFAdminUser)

# Then
# Should always use primary database.
assert result == "default"
conn_call_count = 0
assert create_connection_patch.call_count == conn_call_count
assert conn_patch.is_usable.call_count == conn_call_count
Loading

0 comments on commit bda59f5

Please sign in to comment.