diff --git a/api/app/settings/common.py b/api/app/settings/common.py index bd540e797bd4..5630cf9764cb 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -1155,6 +1155,8 @@ "SEGMENT_CONDITION_VALUE_LIMIT must be between 0 and 2,000,000 (2MB)." ) +SEGMENT_RULES_CONDITIONS_LIMIT = env.int("SEGMENT_RULES_CONDITIONS_LIMIT", 100) + WEBHOOK_BACKOFF_BASE = env.int("WEBHOOK_BACKOFF_BASE", default=2) WEBHOOK_BACKOFF_RETRIES = env.int("WEBHOOK_BACKOFF_RETRIES", default=3) diff --git a/api/segments/migrations/0021_create_whitelisted_segments.py b/api/segments/migrations/0021_create_whitelisted_segments.py new file mode 100644 index 000000000000..3f6eb6f3f17a --- /dev/null +++ b/api/segments/migrations/0021_create_whitelisted_segments.py @@ -0,0 +1,55 @@ +# Generated by Django 3.2.24 on 2024-02-12 19:48 +from django.apps.registry import Apps +from django.db import migrations, models +import django.db.models.deletion +from django.db.backends.base.schema import BaseDatabaseSchemaEditor +from django.conf import settings +from django.db import connection + + + +def create_whitelisted_segments(apps: Apps, schema_editor: BaseDatabaseSchemaEditor) -> None: + model_class = apps.get_model("segments", "WhitelistedSegment") + sql = f""" + SELECT s.id + FROM segments_segment s + LEFT OUTER JOIN segments_segmentrule sr1 ON s.id = sr1.segment_id + LEFT OUTER JOIN segments_segmentrule sr2 ON sr1.id = sr2.rule_id + LEFT OUTER JOIN segments_condition sc ON sr2.id = sc.rule_id + GROUP BY s.id + HAVING COUNT(*) > {settings.SEGMENT_RULES_CONDITIONS_LIMIT} + ORDER BY COUNT(*) DESC; + """ + + whitelisted_segments = [] + with connection.cursor() as cursor: + cursor.execute(sql) + results = cursor.fetchall() + for result in results: + segment_id = result[0] + whitelisted_segments.append(model_class(segment_id=segment_id)) + + model_class.objects.bulk_create(whitelisted_segments) + + +class Migration(migrations.Migration): + + dependencies = [ + ('segments', '0020_detach_segment_from_project_cascade_delete'), + ] + + operations = [ + migrations.CreateModel( + name='WhitelistedSegment', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', models.DateTimeField(auto_now_add=True, null=True)), + ('updated_at', models.DateTimeField(auto_now=True, null=True)), + ('segment', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='whitelisted_segment', to='segments.segment')), + ], + ), + migrations.RunPython( + create_whitelisted_segments, + reverse_code=migrations.RunPython.noop, + ) + ] diff --git a/api/segments/models.py b/api/segments/models.py index 23409d460682..dcf996301713 100644 --- a/api/segments/models.py +++ b/api/segments/models.py @@ -202,3 +202,19 @@ def _get_segment(self) -> Segment: def _get_project(self) -> typing.Optional[Project]: return self.rule.get_segment().project + + +class WhitelistedSegment(models.Model): + """ + In order to grandfather in existing segments, these models represent segments + that do not conform to the SEGMENT_RULES_CONDITIONS_LIMIT and may have + more than the typically allowed number of segment rules and conditions. + """ + + segment = models.OneToOneField( + Segment, + on_delete=models.CASCADE, + related_name="whitelisted_segment", + ) + created_at = models.DateTimeField(null=True, auto_now_add=True) + updated_at = models.DateTimeField(null=True, auto_now=True) diff --git a/api/segments/serializers.py b/api/segments/serializers.py index 28966e65fd7c..4fc9597dbdaa 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -1,5 +1,6 @@ import typing +from django.conf import settings from flag_engine.segments.constants import PERCENTAGE_SPLIT from rest_framework import serializers from rest_framework.exceptions import ValidationError @@ -58,6 +59,7 @@ def create(self, validated_data): self.validate_project_segment_limit(project) rules_data = validated_data.pop("rules", []) + self.validate_segment_rules_conditions_limit(rules_data) # create segment with nested rules and conditions segment = Segment.objects.create(**validated_data) @@ -66,6 +68,15 @@ def create(self, validated_data): ) return segment + def update(self, instance, validated_data): + # use the initial data since we need the ids included to determine which to update & which to create + rules_data = self.initial_data.pop("rules", []) + self.validate_segment_rules_conditions_limit(rules_data) + self._update_segment_rules(rules_data, segment=instance) + # remove rules from validated data to prevent error trying to create segment with nested rules + del validated_data["rules"] + return super().update(instance, validated_data) + def validate_project_segment_limit(self, project: Project) -> None: if project.segments.count() >= project.max_segments_allowed: raise ValidationError( @@ -74,13 +85,38 @@ def validate_project_segment_limit(self, project: Project) -> None: } ) - def update(self, instance, validated_data): - # use the initial data since we need the ids included to determine which to update & which to create - rules_data = self.initial_data.pop("rules", []) - self._update_segment_rules(rules_data, segment=instance) - # remove rules from validated data to prevent error trying to create segment with nested rules - del validated_data["rules"] - return super().update(instance, validated_data) + def validate_segment_rules_conditions_limit( + self, rules_data: dict[str, object] + ) -> None: + if self.instance and getattr(self.instance, "whitelisted_segment", None): + return + + count = self._calculate_condition_count(rules_data) + + if count > settings.SEGMENT_RULES_CONDITIONS_LIMIT: + raise ValidationError( + { + "segment": f"The segment has {count} conditions, which exceeds the maximum " + f"condition count of {settings.SEGMENT_RULES_CONDITIONS_LIMIT}." + } + ) + + def _calculate_condition_count( + self, + rules_data: dict[str, object], + ) -> None: + count: int = 0 + + for rule_data in rules_data: + child_rules = rule_data.get("rules", []) + if child_rules: + count += self._calculate_condition_count(child_rules) + conditions = rule_data.get("conditions", []) + for condition in conditions: + if condition.get("delete", False) is True: + continue + count += 1 + return count def _update_segment_rules(self, rules_data, segment=None): """ diff --git a/api/tests/unit/segments/test_migrations.py b/api/tests/unit/segments/test_migrations.py new file mode 100644 index 000000000000..f2b181a81165 --- /dev/null +++ b/api/tests/unit/segments/test_migrations.py @@ -0,0 +1,104 @@ +import pytest +from django.conf import settings as test_settings +from django_test_migrations.migrator import Migrator +from flag_engine.segments import constants +from pytest_django.fixtures import SettingsWrapper + + +@pytest.mark.skipif( + test_settings.SKIP_MIGRATION_TESTS is True, + reason="Skip migration tests to speed up tests where necessary", +) +def test_create_whitelisted_segments_migration( + migrator: Migrator, + settings: SettingsWrapper, +) -> None: + # Given - The migration state is at 0020 (before the migration we want to test). + old_state = migrator.apply_initial_migration( + ("segments", "0020_detach_segment_from_project_cascade_delete") + ) + + Organisation = old_state.apps.get_model("organisations", "Organisation") + Project = old_state.apps.get_model("projects", "Project") + SegmentRule = old_state.apps.get_model("segments", "SegmentRule") + Segment = old_state.apps.get_model("segments", "Segment") + Condition = old_state.apps.get_model("segments", "Condition") + + # Set the limit lower to allow for a faster test. + settings.SEGMENT_RULES_CONDITIONS_LIMIT = 3 + + # Next, create the setup data. + organisation = Organisation.objects.create(name="Big Corp Incorporated") + project = Project.objects.create(name="Huge Project", organisation=organisation) + + segment_1 = Segment.objects.create(name="Segment1", project=project) + segment_2 = Segment.objects.create(name="Segment1", project=project) + segment_rule_1 = SegmentRule.objects.create( + segment=segment_1, + type="ALL", + ) + + # Subnested segment rules. + segment_rule_2 = SegmentRule.objects.create( + rule=segment_rule_1, + type="ALL", + ) + segment_rule_3 = SegmentRule.objects.create( + rule=segment_rule_1, + type="ALL", + ) + + # Lonely segment rules for pass criteria for segment_2. + segment_rule_4 = SegmentRule.objects.create( + segment=segment_2, + type="ALL", + ) + segment_rule_5 = SegmentRule.objects.create( + rule=segment_rule_4, + type="ALL", + ) + + Condition.objects.create( + operator=constants.EQUAL, + property="age", + value="21", + rule=segment_rule_2, + ) + Condition.objects.create( + operator=constants.GREATER_THAN, + property="height", + value="210", + rule=segment_rule_2, + ) + Condition.objects.create( + operator=constants.GREATER_THAN, + property="waist", + value="36", + rule=segment_rule_3, + ) + Condition.objects.create( + operator=constants.LESS_THAN, + property="shoes", + value="12", + rule=segment_rule_3, + ) + + # Sole criteria for segment_2 conditions. + Condition.objects.create( + operator=constants.LESS_THAN, + property="toy_count", + value="7", + rule=segment_rule_5, + ) + + # When we run the migration. + new_state = migrator.apply_tested_migration( + ("segments", "0021_create_whitelisted_segments") + ) + + # Then the first segment is in the whitelist while the second is not. + NewSegment = new_state.apps.get_model("segments", "Segment") + new_segment_1 = NewSegment.objects.get(id=segment_1.id) + new_segment_2 = NewSegment.objects.get(id=segment_2.id) + assert new_segment_1.whitelisted_segment + assert getattr(new_segment_2, "whitelisted_segment", None) is None diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index e0126b386a1e..54df65ad9cc8 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -5,14 +5,17 @@ from django.contrib.auth import get_user_model from django.urls import reverse from flag_engine.segments.constants import EQUAL +from pytest_django.fixtures import SettingsWrapper from pytest_lazyfixture import lazy_fixture from rest_framework import status +from rest_framework.test import APIClient from audit.models import AuditLog from audit.related_object_type import RelatedObjectType from environments.models import Environment from features.models import Feature -from segments.models import Condition, Segment, SegmentRule +from projects.models import Project +from segments.models import Condition, Segment, SegmentRule, WhitelistedSegment from util.mappers import map_identity_to_identity_document User = get_user_model() @@ -129,7 +132,7 @@ def test_create_segments_reaching_max_limit(project, client, settings): ], } - # Now, let's create the firs segment + # Now, let's create the first segment res = client.post(url, data=json.dumps(data), content_type="application/json") assert res.status_code == status.HTTP_201_CREATED @@ -574,3 +577,208 @@ def test_update_segment_delete_existing_rule(project, client, segment, segment_r assert response.status_code == status.HTTP_200_OK assert segment_rule.conditions.count() == 0 + + +def test_update_segment_obeys_max_conditions( + project: Project, + admin_client: APIClient, + segment: Segment, + segment_rule: SegmentRule, + settings: SettingsWrapper, +) -> None: + # Given + url = reverse( + "api-v1:projects:project-segments-detail", args=[project.id, segment.id] + ) + nested_rule = SegmentRule.objects.create( + rule=segment_rule, type=SegmentRule.ANY_RULE + ) + existing_condition = Condition.objects.create( + rule=nested_rule, property="foo", operator=EQUAL, value="bar" + ) + + # Reduce value for test debugging. + settings.SEGMENT_RULES_CONDITIONS_LIMIT = 10 + new_condition_property = "prop_" + new_condition_value = "red" + new_conditions = [] + for i in range(settings.SEGMENT_RULES_CONDITIONS_LIMIT): + new_conditions.append( + { + "property": f"{new_condition_property}{i}", + "operator": EQUAL, + "value": new_condition_value, + } + ) + + data = { + "name": segment.name, + "project": project.id, + "rules": [ + { + "id": segment_rule.id, + "type": segment_rule.type, + "rules": [ + { + "id": nested_rule.id, + "type": nested_rule.type, + "rules": [], + "conditions": [ + { + "id": existing_condition.id, + "property": existing_condition.property, + "operator": existing_condition.operator, + "value": existing_condition.value, + }, + *new_conditions, + ], + } + ], + "conditions": [], + } + ], + } + + # When + response = admin_client.put( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json() == { + "segment": "The segment has 11 conditions, which exceeds the maximum condition count of 10." + } + + nested_rule.refresh_from_db() + assert nested_rule.conditions.count() == 1 + + +def test_update_segment_evades_max_conditions_when_whitelisted( + project: Project, + admin_client: APIClient, + segment: Segment, + segment_rule: SegmentRule, + settings: SettingsWrapper, +) -> None: + # Given + url = reverse( + "api-v1:projects:project-segments-detail", args=[project.id, segment.id] + ) + nested_rule = SegmentRule.objects.create( + rule=segment_rule, type=SegmentRule.ANY_RULE + ) + existing_condition = Condition.objects.create( + rule=nested_rule, property="foo", operator=EQUAL, value="bar" + ) + + # Create the whitelist to stop the validation. + WhitelistedSegment.objects.create(segment=segment) + + # Reduce value for test debugging. + settings.SEGMENT_RULES_CONDITIONS_LIMIT = 10 + new_condition_property = "prop_" + new_condition_value = "red" + new_conditions = [] + for i in range(settings.SEGMENT_RULES_CONDITIONS_LIMIT): + new_conditions.append( + { + "property": f"{new_condition_property}{i}", + "operator": EQUAL, + "value": new_condition_value, + } + ) + + data = { + "name": segment.name, + "project": project.id, + "rules": [ + { + "id": segment_rule.id, + "type": segment_rule.type, + "rules": [ + { + "id": nested_rule.id, + "type": nested_rule.type, + "rules": [], + "conditions": [ + { + "id": existing_condition.id, + "property": existing_condition.property, + "operator": existing_condition.operator, + "value": existing_condition.value, + }, + *new_conditions, + ], + } + ], + "conditions": [], + } + ], + } + + # When + response = admin_client.put( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_200_OK + nested_rule.refresh_from_db() + assert nested_rule.conditions.count() == 11 + + +def test_create_segment_obeys_max_conditions( + project: Project, + admin_client: APIClient, + settings: SettingsWrapper, +) -> None: + # Given + url = reverse("api-v1:projects:project-segments-list", args=[project.id]) + + # Reduce value for test debugging. + settings.SEGMENT_RULES_CONDITIONS_LIMIT = 10 + new_condition_property = "prop_" + new_condition_value = "red" + new_conditions = [] + for i in range(settings.SEGMENT_RULES_CONDITIONS_LIMIT + 1): + new_conditions.append( + { + "property": f"{new_condition_property}{i}", + "operator": EQUAL, + "value": new_condition_value, + } + ) + + data = { + "name": "segment_name", + "project": project.id, + "rules": [ + { + "conditions": [], + "type": "ALL", + "rules": [ + { + "type": "ANY", + "rules": [], + "conditions": [ + *new_conditions, + ], + } + ], + } + ], + } + + # When + response = admin_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json() == { + "segment": "The segment has 11 conditions, which exceeds the maximum condition count of 10." + } + + assert Segment.objects.count() == 0 diff --git a/docs/docs/system-administration/system-limits.md b/docs/docs/system-administration/system-limits.md index 5eab5a684d70..65f64976c6cf 100644 --- a/docs/docs/system-administration/system-limits.md +++ b/docs/docs/system-administration/system-limits.md @@ -22,6 +22,7 @@ In order to ensure consistent performance, Flagsmith has the following limitatio - **400** Features per Project - **100** Segments per Project - **100** Segment Overrides per Environment +- **100** Segment Rules Conditions ### Entity Data Elements