Skip to content

Commit

Permalink
fix: Limit segment rules and conditions (#3397)
Browse files Browse the repository at this point in the history
Co-authored-by: Matthew Elwell <[email protected]>
  • Loading branch information
zachaysan and matthewelwell authored Feb 14, 2024
1 parent c666d29 commit c89e96e
Show file tree
Hide file tree
Showing 7 changed files with 431 additions and 9 deletions.
2 changes: 2 additions & 0 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,8 @@
"SEGMENT_CONDITION_VALUE_LIMIT must be between 0 and 2,000,000 (2MB)."
)

SEGMENT_RULES_CONDITIONS_LIMIT = env.int("SEGMENT_RULES_CONDITIONS_LIMIT", 100)

WEBHOOK_BACKOFF_BASE = env.int("WEBHOOK_BACKOFF_BASE", default=2)
WEBHOOK_BACKOFF_RETRIES = env.int("WEBHOOK_BACKOFF_RETRIES", default=3)

Expand Down
55 changes: 55 additions & 0 deletions api/segments/migrations/0021_create_whitelisted_segments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Generated by Django 3.2.24 on 2024-02-12 19:48
from django.apps.registry import Apps
from django.db import migrations, models
import django.db.models.deletion
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.conf import settings
from django.db import connection



def create_whitelisted_segments(apps: Apps, schema_editor: BaseDatabaseSchemaEditor) -> None:
model_class = apps.get_model("segments", "WhitelistedSegment")
sql = f"""
SELECT s.id
FROM segments_segment s
LEFT OUTER JOIN segments_segmentrule sr1 ON s.id = sr1.segment_id
LEFT OUTER JOIN segments_segmentrule sr2 ON sr1.id = sr2.rule_id
LEFT OUTER JOIN segments_condition sc ON sr2.id = sc.rule_id
GROUP BY s.id
HAVING COUNT(*) > {settings.SEGMENT_RULES_CONDITIONS_LIMIT}
ORDER BY COUNT(*) DESC;
"""

whitelisted_segments = []
with connection.cursor() as cursor:
cursor.execute(sql)
results = cursor.fetchall()
for result in results:
segment_id = result[0]
whitelisted_segments.append(model_class(segment_id=segment_id))

model_class.objects.bulk_create(whitelisted_segments)


class Migration(migrations.Migration):

dependencies = [
('segments', '0020_detach_segment_from_project_cascade_delete'),
]

operations = [
migrations.CreateModel(
name='WhitelistedSegment',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('created_at', models.DateTimeField(auto_now_add=True, null=True)),
('updated_at', models.DateTimeField(auto_now=True, null=True)),
('segment', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='whitelisted_segment', to='segments.segment')),
],
),
migrations.RunPython(
create_whitelisted_segments,
reverse_code=migrations.RunPython.noop,
)
]
16 changes: 16 additions & 0 deletions api/segments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,19 @@ def _get_segment(self) -> Segment:

def _get_project(self) -> typing.Optional[Project]:
return self.rule.get_segment().project


class WhitelistedSegment(models.Model):
"""
In order to grandfather in existing segments, these models represent segments
that do not conform to the SEGMENT_RULES_CONDITIONS_LIMIT and may have
more than the typically allowed number of segment rules and conditions.
"""

segment = models.OneToOneField(
Segment,
on_delete=models.CASCADE,
related_name="whitelisted_segment",
)
created_at = models.DateTimeField(null=True, auto_now_add=True)
updated_at = models.DateTimeField(null=True, auto_now=True)
50 changes: 43 additions & 7 deletions api/segments/serializers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

from django.conf import settings
from flag_engine.segments.constants import PERCENTAGE_SPLIT
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
Expand Down Expand Up @@ -58,6 +59,7 @@ def create(self, validated_data):
self.validate_project_segment_limit(project)

rules_data = validated_data.pop("rules", [])
self.validate_segment_rules_conditions_limit(rules_data)

# create segment with nested rules and conditions
segment = Segment.objects.create(**validated_data)
Expand All @@ -66,6 +68,15 @@ def create(self, validated_data):
)
return segment

def update(self, instance, validated_data):
# use the initial data since we need the ids included to determine which to update & which to create
rules_data = self.initial_data.pop("rules", [])
self.validate_segment_rules_conditions_limit(rules_data)
self._update_segment_rules(rules_data, segment=instance)
# remove rules from validated data to prevent error trying to create segment with nested rules
del validated_data["rules"]
return super().update(instance, validated_data)

def validate_project_segment_limit(self, project: Project) -> None:
if project.segments.count() >= project.max_segments_allowed:
raise ValidationError(
Expand All @@ -74,13 +85,38 @@ def validate_project_segment_limit(self, project: Project) -> None:
}
)

def update(self, instance, validated_data):
# use the initial data since we need the ids included to determine which to update & which to create
rules_data = self.initial_data.pop("rules", [])
self._update_segment_rules(rules_data, segment=instance)
# remove rules from validated data to prevent error trying to create segment with nested rules
del validated_data["rules"]
return super().update(instance, validated_data)
def validate_segment_rules_conditions_limit(
self, rules_data: dict[str, object]
) -> None:
if self.instance and getattr(self.instance, "whitelisted_segment", None):
return

count = self._calculate_condition_count(rules_data)

if count > settings.SEGMENT_RULES_CONDITIONS_LIMIT:
raise ValidationError(
{
"segment": f"The segment has {count} conditions, which exceeds the maximum "
f"condition count of {settings.SEGMENT_RULES_CONDITIONS_LIMIT}."
}
)

def _calculate_condition_count(
self,
rules_data: dict[str, object],
) -> None:
count: int = 0

for rule_data in rules_data:
child_rules = rule_data.get("rules", [])
if child_rules:
count += self._calculate_condition_count(child_rules)
conditions = rule_data.get("conditions", [])
for condition in conditions:
if condition.get("delete", False) is True:
continue
count += 1
return count

def _update_segment_rules(self, rules_data, segment=None):
"""
Expand Down
104 changes: 104 additions & 0 deletions api/tests/unit/segments/test_migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
from django.conf import settings as test_settings
from django_test_migrations.migrator import Migrator
from flag_engine.segments import constants
from pytest_django.fixtures import SettingsWrapper


@pytest.mark.skipif(
test_settings.SKIP_MIGRATION_TESTS is True,
reason="Skip migration tests to speed up tests where necessary",
)
def test_create_whitelisted_segments_migration(
migrator: Migrator,
settings: SettingsWrapper,
) -> None:
# Given - The migration state is at 0020 (before the migration we want to test).
old_state = migrator.apply_initial_migration(
("segments", "0020_detach_segment_from_project_cascade_delete")
)

Organisation = old_state.apps.get_model("organisations", "Organisation")
Project = old_state.apps.get_model("projects", "Project")
SegmentRule = old_state.apps.get_model("segments", "SegmentRule")
Segment = old_state.apps.get_model("segments", "Segment")
Condition = old_state.apps.get_model("segments", "Condition")

# Set the limit lower to allow for a faster test.
settings.SEGMENT_RULES_CONDITIONS_LIMIT = 3

# Next, create the setup data.
organisation = Organisation.objects.create(name="Big Corp Incorporated")
project = Project.objects.create(name="Huge Project", organisation=organisation)

segment_1 = Segment.objects.create(name="Segment1", project=project)
segment_2 = Segment.objects.create(name="Segment1", project=project)
segment_rule_1 = SegmentRule.objects.create(
segment=segment_1,
type="ALL",
)

# Subnested segment rules.
segment_rule_2 = SegmentRule.objects.create(
rule=segment_rule_1,
type="ALL",
)
segment_rule_3 = SegmentRule.objects.create(
rule=segment_rule_1,
type="ALL",
)

# Lonely segment rules for pass criteria for segment_2.
segment_rule_4 = SegmentRule.objects.create(
segment=segment_2,
type="ALL",
)
segment_rule_5 = SegmentRule.objects.create(
rule=segment_rule_4,
type="ALL",
)

Condition.objects.create(
operator=constants.EQUAL,
property="age",
value="21",
rule=segment_rule_2,
)
Condition.objects.create(
operator=constants.GREATER_THAN,
property="height",
value="210",
rule=segment_rule_2,
)
Condition.objects.create(
operator=constants.GREATER_THAN,
property="waist",
value="36",
rule=segment_rule_3,
)
Condition.objects.create(
operator=constants.LESS_THAN,
property="shoes",
value="12",
rule=segment_rule_3,
)

# Sole criteria for segment_2 conditions.
Condition.objects.create(
operator=constants.LESS_THAN,
property="toy_count",
value="7",
rule=segment_rule_5,
)

# When we run the migration.
new_state = migrator.apply_tested_migration(
("segments", "0021_create_whitelisted_segments")
)

# Then the first segment is in the whitelist while the second is not.
NewSegment = new_state.apps.get_model("segments", "Segment")
new_segment_1 = NewSegment.objects.get(id=segment_1.id)
new_segment_2 = NewSegment.objects.get(id=segment_2.id)
assert new_segment_1.whitelisted_segment
assert getattr(new_segment_2, "whitelisted_segment", None) is None
Loading

0 comments on commit c89e96e

Please sign in to comment.